import os
import tempfile
import uuid

# `spaces` is imported early so the ZeroGPU runtime can patch CUDA before torch is used.
import spaces

import gradio as gr
import numpy as np
import torch
from dotenv import load_dotenv
from pydub import AudioSegment
from scipy.io.wavfile import write
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    MusicgenForConditionalGeneration,
    pipeline,
)
from TTS.api import TTS
# -----------------------------------------------------------
# Initialization & Environment Setup
# -----------------------------------------------------------
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")

# -----------------------------------------------------------
# Model Cache Management
# -----------------------------------------------------------
# Loaded models are kept in module-level dicts so repeated requests reuse the
# same weights instead of reloading them from disk on every call.
LLAMA_PIPELINES = {}
MUSICGEN_MODELS = {}
TTS_MODELS = {}
def get_llama_pipeline(model_id: str, token: str):
    """Load and cache the LLaMA text-generation pipeline."""
    if model_id in LLAMA_PIPELINES:
        return LLAMA_PIPELINES[model_id]
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=token,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    text_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
    LLAMA_PIPELINES[model_id] = text_pipeline
    return text_pipeline
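
# Warm-up sketch (optional): calling the getter once at startup pre-populates the
# cache so the first user request doesn't pay the model-load cost, e.g.:
#   _ = get_llama_pipeline("meta-llama/Meta-Llama-3-8B-Instruct", HF_TOKEN)
# On ZeroGPU this is usually left lazy, since a GPU is only attached per request.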
def get_musicgen_model(model_key: str = "facebook/musicgen-large"):
    """Load and cache the MusicGen model and processor."""
    if model_key in MUSICGEN_MODELS:
        return MUSICGEN_MODELS[model_key]
    model = MusicgenForConditionalGeneration.from_pretrained(model_key)
    processor = AutoProcessor.from_pretrained(model_key)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    MUSICGEN_MODELS[model_key] = (model, processor)
    return model, processor
def get_tts_model(model_name: str = "tts_models/en/ljspeech/tacotron2-DDC"):
    """Load and cache the TTS model."""
    if model_name in TTS_MODELS:
        return TTS_MODELS[model_name]
    tts_model = TTS(model_name)
    TTS_MODELS[model_name] = tts_model
    return tts_model
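
# All three getters follow the same lazy-load-then-cache pattern; for example
# (illustrative):
#   tts = get_tts_model()          # first call loads the weights
#   assert get_tts_model() is tts  # subsequent calls return the cached instance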
# -----------------------------------------------------------
# Core Functionality
# -----------------------------------------------------------
@spaces.GPU(duration=120)  # ZeroGPU attaches a GPU only while a @spaces.GPU-decorated call runs.
def generate_script(user_prompt: str, model_id: str, token: str, duration: int):
    """
    Generate a professional promo script consisting of a voice-over script,
    sound design suggestions, and music recommendations.
    """
    try:
        text_pipeline = get_llama_pipeline(model_id, token)
        # The prompt instructs the model to emit three sections with explicit
        # headers, which the parser below relies on.
        system_prompt = (
            f"You are a professional audio producer creating {duration}-second content. "
            "Please generate the following three sections exactly as shown:\n\n"
            "Voice-Over Script: [A clear and concise script for the voiceover.]\n"
            "Sound Design Suggestions: [Specific ideas, effects, and ambience recommendations.]\n"
            "Music Suggestions: [Recommendations for music style, genre, and tempo.]\n\n"
            "Make sure each section starts with its header exactly."
        )
        full_prompt = f"{system_prompt}\nClient brief: {user_prompt}\nOutput:"
        with torch.inference_mode():
            result = text_pipeline(
                full_prompt,
                max_new_tokens=400,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )
        generated_text = result[0]["generated_text"].split("Output:")[-1].strip()

        # Parse the output into the three expected sections.
        sections = {
            "Voice-Over Script:": "",
            "Sound Design Suggestions:": "",
            "Music Suggestions:": "",
        }
        current_section = None
        for line in generated_text.split("\n"):
            for section in sections:
                if section in line:
                    current_section = section
                    # Strip the header from the line before accumulating it.
                    line = line.replace(section, "").strip()
                    break
            if current_section:
                sections[current_section] += line + "\n"
        return (
            sections["Voice-Over Script:"].strip() or "No script generated",
            sections["Sound Design Suggestions:"].strip() or "No sound design suggestions",
            sections["Music Suggestions:"].strip() or "No music suggestions",
        )
    except Exception as e:
        return f"Error: {str(e)}", "", ""
def generate_voice(script: str, tts_model_name: str):
    """
    Generate the full voice-over audio from the provided script using a TTS model.
    """
    try:
        if not script.strip():
            return None
        tts_model = get_tts_model(tts_model_name)
        # Create a unique temporary file name so concurrent requests don't collide.
        output_path = os.path.join(tempfile.gettempdir(), f"voice_{uuid.uuid4().hex}.wav")
        tts_model.tts_to_file(text=script, file_path=output_path)
        return output_path
    except Exception as e:
        print(f"Voice generation error: {e}")
        return None


def generate_voice_preview(script: str, tts_model_name: str):
    """
    Generate a short preview of the voice-over from the first 100 words of the script.
    """
    try:
        if not script.strip():
            return None
        words = script.split()
        preview_text = " ".join(words[:100]) if len(words) > 100 else script
        return generate_voice(preview_text, tts_model_name)
    except Exception as e:
        print(f"Voice preview error: {e}")
        return None
@spaces.GPU(duration=120)  # MusicGen decoding is GPU-bound; reserve a ZeroGPU slot.
def generate_music(prompt: str, audio_length: int):
    """
    Generate music audio from a text prompt using the MusicGen model.
    """
    try:
        model, processor = get_musicgen_model()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        inputs = processor(text=[prompt], padding=True, return_tensors="pt").to(device)
        with torch.inference_mode():
            outputs = model.generate(**inputs, max_new_tokens=audio_length)
        # outputs has shape (batch, channels, samples); take the first mono waveform.
        audio_data = outputs[0, 0].cpu().numpy()
        # Normalize to 16-bit PCM, guarding against division by zero on silent output.
        max_val = np.max(np.abs(audio_data))
        if max_val == 0:
            normalized_audio = audio_data.astype("int16")
        else:
            normalized_audio = (audio_data / max_val * 32767).astype("int16")
        # Write at the model's own sampling rate (32 kHz for MusicGen); a hard-coded
        # 44.1 kHz would pitch-shift and speed up the result.
        sampling_rate = model.config.audio_encoder.sampling_rate
        output_path = os.path.join(tempfile.gettempdir(), f"music_{uuid.uuid4().hex}.wav")
        write(output_path, sampling_rate, normalized_audio)
        return output_path
    except Exception as e:
        print(f"Music generation error: {e}")
        return None
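
# Rough length math: MusicGen decodes about 50 audio tokens per second, so the
# UI's 256-1024 token range corresponds to roughly 5-20 seconds of music.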
def blend_audio(voice_path: str, music_path: str, ducking: bool, duck_level: int):
    """
    Blend the generated voice and music audio files.
    If ducking is enabled, lower the music volume underneath the voice.
    """
    try:
        if not voice_path or not music_path:
            print("Mixing error: generate both a voice-over and a music track first.")
            return None
        voice = AudioSegment.from_wav(voice_path)
        music = AudioSegment.from_wav(music_path)
        # Loop the music track if it's shorter than the voice track, then trim
        # it to the voice length.
        if len(music) < len(voice):
            loops_needed = (len(voice) // len(music)) + 1
            music = music * loops_needed
        music = music[:len(voice)]
        if ducking:
            ducked_music = music - duck_level
            final_audio = ducked_music.overlay(voice)
        else:
            final_audio = music.overlay(voice)
        output_path = os.path.join(tempfile.gettempdir(), f"final_mix_{uuid.uuid4().hex}.wav")
        final_audio.export(output_path, format="wav")
        return output_path
    except Exception as e:
        print(f"Mixing error: {e}")
        return None
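
# Ducking math: pydub gain arithmetic is in dB, so `music - 10` attenuates the
# music bed by 10 dB, i.e. to about 10 ** (-10 / 20) ~= 32% of its amplitude.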
# -----------------------------------------------------------
# Enhanced UI Components
# -----------------------------------------------------------
custom_css = """
#main-container {
    max-width: 1200px;
    margin: 0 auto;
    padding: 20px;
    background: #f0f9fb;
    border-radius: 15px;
    box-shadow: 0 4px 6px rgba(0,0,0,0.05);
}

.header {
    text-align: center;
    padding: 2em;
    background: linear-gradient(135deg, #2a9d8f 0%, #457b9d 100%);
    color: white;
    border-radius: 15px;
    margin-bottom: 2em;
    border: 1px solid #264653;
}

.tab-nav {
    background: none !important;
    border: none !important;
}

.tab-button {
    padding: 1em 2em !important;
    border-radius: 8px !important;
    margin: 0 5px !important;
    transition: all 0.3s ease !important;
    background: #e9f5f4 !important;
    border: 1px solid #a8dadc !important;
    color: #1d3557 !important;
}

.tab-button:hover {
    transform: translateY(-2px);
    box-shadow: 0 3px 6px rgba(42,157,143,0.2);
    background: #caf0f8 !important;
}

.dark-btn {
    background: linear-gradient(135deg, #457b9d 0%, #2a9d8f 100%) !important;
    color: white !important;
    border: none !important;
    padding: 12px 24px !important;
    border-radius: 8px !important;
    transition: transform 0.2s ease !important;
}

.dark-btn:hover {
    transform: scale(1.02);
    box-shadow: 0 3px 8px rgba(42,157,143,0.3);
}

.output-card {
    background: #f8fbfe !important;
    border-radius: 10px !important;
    padding: 20px !important;
    box-shadow: 0 2px 4px rgba(69,123,157,0.1) !important;
    border: 1px solid #e2e8f0;
}

.progress-indicator {
    color: #457b9d;
    font-style: italic;
    margin-top: 10px;
}

/* Additional Color Elements */
h1, h2, h3 {
    color: #1d3557 !important;
}

audio {
    border: 1px solid #a8dadc !important;
    border-radius: 8px !important;
}

.slider-handle {
    background: #2a9d8f !important;
}
"""
with gr.Blocks(css=custom_css, theme=gr.themes.Default()) as demo:
    with gr.Column(elem_id="main-container"):
        # Header Section
        with gr.Column(elem_classes="header"):
            gr.Markdown("""
            # 🎙️ AI Promo Studio
            **Professional Audio Production Suite Powered by AI**
            """)

        # Main Workflow Tabs
        with gr.Tabs(elem_classes="tab-nav"):
            # Script Generation Tab
            with gr.Tab("📝 Script Design", elem_classes="tab-button"):
                with gr.Row(equal_height=False):
                    with gr.Column(scale=2):
                        gr.Markdown("### 🎯 Project Brief")
                        user_prompt = gr.Textbox(
                            label="Describe your promo concept",
                            placeholder="e.g., 'An intense 30-second movie trailer intro with epic orchestral music and dramatic sound effects...'",
                            lines=4,
                        )
                        with gr.Row():
                            duration = gr.Slider(
                                label="Duration (seconds)",
                                minimum=15,
                                maximum=120,
                                step=15,
                                value=30,
                                interactive=True,
                            )
                            llama_model_id = gr.Dropdown(
                                label="AI Model",
                                choices=["meta-llama/Meta-Llama-3-8B-Instruct"],
                                value="meta-llama/Meta-Llama-3-8B-Instruct",
                                interactive=True,
                            )
                        generate_btn = gr.Button("Generate Script 🚀", elem_classes="dark-btn")
                    with gr.Column(scale=1, elem_classes="output-card"):
                        gr.Markdown("### 📄 Generated Content")
                        script_output = gr.Textbox(label="Voice Script", lines=6)
                        sound_design_output = gr.Textbox(label="Sound Design", lines=3)
                        music_suggestion_output = gr.Textbox(label="Music Style", lines=3)
            # Voice Production Tab
            with gr.Tab("🎙️ Voice Production", elem_classes="tab-button"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### 🔊 Voice Settings")
                        tts_model = gr.Dropdown(
                            label="Voice Model",
                            choices=[
                                "tts_models/en/ljspeech/tacotron2-DDC",
                                "tts_models/en/ljspeech/vits",
                                "tts_models/en/sam/tacotron-DDC",
                            ],
                            value="tts_models/en/ljspeech/tacotron2-DDC",
                            interactive=True,
                        )
                        with gr.Row():
                            voice_preview_btn = gr.Button("Preview Sample", elem_classes="dark-btn")
                            voice_generate_btn = gr.Button("Generate Full Voiceover", elem_classes="dark-btn")
                    with gr.Column(scale=1, elem_classes="output-card"):
                        gr.Markdown("### 🎧 Voice Preview")
                        voice_audio = gr.Audio(
                            label="Generated Voice",
                            type="filepath",  # blend_audio expects file paths, not (rate, array) tuples
                            interactive=False,
                            waveform_options={"show_controls": True},
                        )
            # Music Production Tab
            with gr.Tab("🎵 Music Design", elem_classes="tab-button"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### 🎹 Music Parameters")
                        audio_length = gr.Slider(
                            label="Generation Length (tokens)",
                            minimum=256,
                            maximum=1024,
                            step=64,
                            value=512,
                            info="Higher values = longer audio and longer generation time",
                        )
                        music_generate_btn = gr.Button("Generate Music Track", elem_classes="dark-btn")
                    with gr.Column(scale=1, elem_classes="output-card"):
                        gr.Markdown("### 🎶 Music Preview")
                        music_output = gr.Audio(
                            label="Generated Music",
                            type="filepath",  # blend_audio expects file paths, not (rate, array) tuples
                            interactive=False,
                            waveform_options={"show_controls": True},
                        )
            # Final Mix Tab
            with gr.Tab("🎚 Final Mix", elem_classes="tab-button"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### 🎛️ Mixing Console")
                        ducking_enabled = gr.Checkbox(
                            label="Enable Voice Ducking",
                            value=True,
                            info="Automatically lower music during voice segments",
                        )
                        duck_level = gr.Slider(
                            label="Ducking Intensity (dB)",
                            minimum=3,
                            maximum=20,
                            step=1,
                            value=10,
                        )
                        mix_btn = gr.Button("Generate Final Mix", elem_classes="dark-btn")
                    with gr.Column(scale=1, elem_classes="output-card"):
                        gr.Markdown("### 🎧 Final Production")
                        final_mix = gr.Audio(
                            label="Mixed Output",
                            interactive=False,
                            waveform_options={"show_controls": True},
                        )

        # Footer Section
        with gr.Column(elem_classes="output-card"):
            gr.Markdown("""
            <div style="text-align: center; padding: 1.5em 0;">
                <a href="https://bilsimaging.com" target="_blank">
                    <img src="https://bilsimaging.com/logo.png" alt="Bils Imaging" style="height: 35px; margin-right: 15px;">
                </a>
                <a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2Fradiogold">
                    <img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2Fradiogold&countColor=%23263759" />
                </a>
            </div>
            <p style="text-align: center; color: #666; font-size: 0.9em;">
                Professional Audio Production Suite v2.1 © 2024 | Bils Imaging
            </p>
            """)
    # -----------------------------------------------------------
    # Event Handling
    # -----------------------------------------------------------
    # Hidden textbox carrying HF_TOKEN (populated from the environment variable).
    hf_token_hidden = gr.Textbox(value=HF_TOKEN, visible=False)

    generate_btn.click(
        generate_script,
        inputs=[user_prompt, llama_model_id, hf_token_hidden, duration],
        outputs=[script_output, sound_design_output, music_suggestion_output],
    )
    # Voice preview: synthesizes a trimmed version of the script.
    voice_preview_btn.click(
        generate_voice_preview,
        inputs=[script_output, tts_model],
        outputs=voice_audio,
    )
    # Full voice generation using the complete script.
    voice_generate_btn.click(
        generate_voice,
        inputs=[script_output, tts_model],
        outputs=voice_audio,
    )
    music_generate_btn.click(
        generate_music,
        inputs=[music_suggestion_output, audio_length],
        outputs=music_output,
    )
    mix_btn.click(
        blend_audio,
        inputs=[voice_audio, music_output, ducking_enabled, duck_level],
        outputs=final_mix,
    )


if __name__ == "__main__":
    demo.launch(debug=True)