# Hugging Face Space status: Running on Zero (ZeroGPU hardware)
import gradio as gr | |
import os | |
import torch | |
from transformers import ( | |
AutoTokenizer, | |
AutoModelForCausalLM, | |
pipeline, | |
AutoProcessor, | |
MusicgenForConditionalGeneration, | |
) | |
from scipy.io.wavfile import write | |
from pydub import AudioSegment | |
from dotenv import load_dotenv | |
import tempfile | |
import spaces | |
from TTS.api import TTS | |
# -----------------------------------------------------------
# Initialization & Environment Setup
# -----------------------------------------------------------
# Load variables from a .env file into the process environment, then read the
# Hugging Face access token that the UI forwards to the Llama pipeline loader.
load_dotenv()
HF_TOKEN = os.environ.get("HF_TOKEN")  # None when the variable is not set
# -----------------------------------------------------------
# Model Cache Management
# -----------------------------------------------------------
# Module-level caches filled by the get_*_model / get_llama_pipeline helpers,
# so repeated calls reuse an already-loaded checkpoint.
LLAMA_PIPELINES: dict = {}  # model_id -> text-generation pipeline
MUSICGEN_MODELS: dict = {}  # model_key -> (model, processor)
TTS_MODELS: dict = {}  # model_name -> TTS instance
def get_llama_pipeline(model_id: str, token: str):
    """Return a cached ``text-generation`` pipeline for *model_id*.

    On the first call for a given id the tokenizer and model are loaded
    (fp16 weights, ``device_map="auto"``) and the resulting pipeline is stored
    in ``LLAMA_PIPELINES``; later calls return the cached pipeline.

    Args:
        model_id: Hugging Face Hub model id.
        token: Hugging Face access token used for the download (the UI passes
            ``HF_TOKEN``, which is ``None`` when unset).

    Returns:
        A transformers ``text-generation`` pipeline.
    """
    if model_id in LLAMA_PIPELINES:
        return LLAMA_PIPELINES[model_id]

    # ``token`` replaces the deprecated ``use_auth_token`` keyword.
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=token,
        torch_dtype=torch.float16,
        device_map="auto",
        # NOTE(review): trust_remote_code executes Python code shipped in the
        # model repo. The Llama 3 checkpoint offered in the UI appears to be
        # supported natively by transformers; consider dropping this flag.
        trust_remote_code=True,
    )
    text_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
    LLAMA_PIPELINES[model_id] = text_pipeline
    return text_pipeline
def get_musicgen_model(model_key: str = "facebook/musicgen-large"):
    """Return the cached ``(model, processor)`` pair for a MusicGen checkpoint.

    The first request for *model_key* loads both objects, moves the model to
    CUDA when ``torch.cuda.is_available()`` (CPU otherwise) and stores the pair
    in ``MUSICGEN_MODELS``.
    """
    if model_key not in MUSICGEN_MODELS:
        musicgen = MusicgenForConditionalGeneration.from_pretrained(model_key)
        musicgen_processor = AutoProcessor.from_pretrained(model_key)
        musicgen.to("cuda" if torch.cuda.is_available() else "cpu")
        MUSICGEN_MODELS[model_key] = (musicgen, musicgen_processor)
    return MUSICGEN_MODELS[model_key]
def get_tts_model(model_name: str = "tts_models/en/ljspeech/tacotron2-DDC"):
    """Return a cached ``TTS`` instance for *model_name*, creating it on first use."""
    try:
        return TTS_MODELS[model_name]
    except KeyError:
        pass
    loaded = TTS(model_name)
    TTS_MODELS[model_name] = loaded
    return loaded
# -----------------------------------------------------------
# Core Functionality
# -----------------------------------------------------------
# Section headers the LLM is instructed to emit, in the order they are returned.
_SCRIPT_SECTIONS = ("Voice-Over Script", "Sound Design Suggestions", "Music Suggestions")
# Leading characters ignored when looking for a header: whitespace, Markdown
# emphasis/heading marks, list bullets and list numbering.
_HEADER_PREFIX_CHARS = " \t*#_-0123456789.)"


def _parse_script_sections(text: str) -> dict:
    """Split LLM output into the sections named in ``_SCRIPT_SECTIONS``.

    A line opens a section when, after skipping leading Markdown/list
    decoration, it starts with a section name (case-insensitive). Any text on
    the header line after the name (minus a colon / emphasis marks) is kept as
    content. Text before the first header is discarded.

    Returns:
        Mapping of section name -> stripped section text ("" when absent).
    """
    collected = {name: [] for name in _SCRIPT_SECTIONS}
    current = None
    for line in text.splitlines():
        head = line.lstrip(_HEADER_PREFIX_CHARS)
        for name in _SCRIPT_SECTIONS:
            if head.lower().startswith(name.lower()):
                current = name
                line = head[len(name):].lstrip(":*_ \t")
                break
        if current is not None:
            collected[current].append(line)
    return {name: "\n".join(lines).strip() for name, lines in collected.items()}


@spaces.GPU
def generate_script(user_prompt: str, model_id: str, token: str, duration: int):
    """Ask the LLM for a promo script and split the reply into three parts.

    Args:
        user_prompt: Free-text client brief from the UI.
        model_id: Hub id of the text-generation model.
        token: Hugging Face access token forwarded to ``get_llama_pipeline``.
        duration: Target promo length in seconds (used only in the prompt).

    Returns:
        ``(voice_script, sound_design, music_suggestions)``. Each part falls
        back to a placeholder message when the model did not produce it. On
        any exception, returns ``("Error: <message>", "", "")`` so the UI can
        display the problem.
    """
    try:
        # NOTE(review): the pipeline is loaded lazily inside this GPU-allocated
        # call, so on ZeroGPU the first call's download/load time counts
        # against the allocation; consider preloading at import time.
        text_pipeline = get_llama_pipeline(model_id, token)
        # The headers requested here must match _SCRIPT_SECTIONS, otherwise
        # the parser finds nothing.
        system_prompt = (
            f"You are a professional audio producer creating {duration}-second content.\n"
            "Write exactly three sections, each introduced by its header on its own line:\n"
            "Voice-Over Script: the voice script (clear and concise)\n"
            "Sound Design Suggestions: specific sound effects\n"
            "Music Suggestions: music style recommendations (genre, tempo)"
        )
        full_prompt = f"{system_prompt}\nClient brief: {user_prompt}\nOutput:"
        with torch.inference_mode():
            result = text_pipeline(
                full_prompt,
                max_new_tokens=400,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                # Return only the continuation: the prompt itself names the
                # headers and would otherwise be parsed as content.
                return_full_text=False,
            )
        sections = _parse_script_sections(result[0]["generated_text"])
        return (
            sections["Voice-Over Script"] or "No script generated",
            sections["Sound Design Suggestions"] or "No sound design suggestions",
            sections["Music Suggestions"] or "No music suggestions",
        )
    except Exception as e:
        return f"Error: {str(e)}", "", ""
def generate_voice(script: str, tts_model_name: str):
    """Synthesize *script* to speech and return the path of the WAV file.

    Args:
        script: Text to speak. ``None``, empty or whitespace-only input
            produces no audio.
        tts_model_name: Model name passed to ``get_tts_model``.

    Returns:
        Path to a newly created WAV file, or ``None`` when there is nothing to
        speak or synthesis failed (the error is printed).
    """
    if not script or not script.strip():
        return None
    try:
        tts_model = get_tts_model(tts_model_name)
        # Use a unique file per request. A fixed name lets concurrent users
        # overwrite each other's audio before Gradio serves it.
        fd, output_path = tempfile.mkstemp(prefix="voice_", suffix=".wav")
        os.close(fd)
        tts_model.tts_to_file(text=script, file_path=output_path)
        return output_path
    except Exception as e:
        print(f"Voice generation error: {e}")
        return None
@spaces.GPU
def generate_music(prompt: str, audio_length: int):
    """Generate a music clip with MusicGen and save it as a 16-bit WAV file.

    Args:
        prompt: Text description of the desired music (the UI passes the
            LLM's music suggestions).
        audio_length: Number of MusicGen tokens to generate
            (``max_new_tokens``). Larger values give longer clips.

    Returns:
        Path to a newly created WAV file, or ``None`` if generation failed
        (the error is printed).
    """
    try:
        model, processor = get_musicgen_model()
        inputs = processor(text=[prompt], padding=True, return_tensors="pt").to(model.device)
        with torch.inference_mode():
            # Gradio sliders may deliver floats; max_new_tokens must be an int.
            outputs = model.generate(**inputs, max_new_tokens=int(audio_length))
        # First batch item, first channel.
        audio_data = outputs[0, 0].float().cpu().numpy()
        # Peak-normalize, guarding against all-silent output (peak == 0),
        # which would otherwise divide by zero and produce NaNs.
        peak = float(abs(audio_data).max()) if audio_data.size else 0.0
        if peak > 0:
            audio_data = audio_data / peak
        normalized_audio = (audio_data * 32767).astype("int16")
        # Write at the codec's native rate (typically 32 kHz for MusicGen).
        # A different rate shifts pitch and playback speed.
        sampling_rate = model.config.audio_encoder.sampling_rate
        # Use a unique file per request so concurrent users don't clobber each other.
        fd, output_path = tempfile.mkstemp(prefix="music_", suffix=".wav")
        os.close(fd)
        write(output_path, sampling_rate, normalized_audio)
        return output_path
    except Exception as e:
        print(f"Music generation error: {e}")
        return None
def blend_audio(voice_path: str, music_path: str, ducking: bool, duck_level: int):
    """Overlay the voice-over on the music bed and export the mix as WAV.

    The music is looped if it is shorter than the voice, then trimmed to the
    voice's length. When *ducking* is enabled, the whole trimmed music track
    is attenuated by *duck_level* dB before the voice is overlaid. Because the
    music is cut to the voice's length, this lowers it for the entire mix
    rather than only around spoken passages.

    Args:
        voice_path: Path to the voice WAV file.
        music_path: Path to the music WAV file.
        ducking: Whether to attenuate the music.
        duck_level: Attenuation in dB applied when *ducking* is true.

    Returns:
        Path to a newly created WAV file, or ``None`` if an input is missing
        or mixing failed (the error is printed).
    """
    if not voice_path or not music_path:
        print("Mixing error: both a voice track and a music track are required")
        return None
    try:
        voice = AudioSegment.from_wav(voice_path)
        music = AudioSegment.from_wav(music_path)
        if len(music) == 0:
            raise ValueError("music track is empty")
        # Loop the music until it is at least as long as the voice, then trim.
        if len(music) < len(voice):
            loops_needed = (len(voice) // len(music)) + 1
            music = music * loops_needed
        music = music[:len(voice)]
        if ducking:
            music = music - duck_level
        final_audio = music.overlay(voice)
        # Use a unique file per request so concurrent users don't clobber each other.
        fd, output_path = tempfile.mkstemp(prefix="final_mix_", suffix=".wav")
        os.close(fd)
        final_audio.export(output_path, format="wav")
        return output_path
    except Exception as e:
        print(f"Mixing error: {e}")
        return None
# -----------------------------------------------------------
# Enhanced UI Components
# -----------------------------------------------------------
# Stylesheet passed to gr.Blocks(css=...). Its selectors target the elem_id /
# elem_classes values assigned in the layout ("main-container", "header",
# "tab-nav", "tab-button", "dark-btn", "output-card").
# ".progress-indicator" is not referenced by any component in the visible code.
# NOTE(review): elem_classes on gr.Tab may be applied to the tab's content
# panel rather than its header button. In that case ".tab-button" (including
# the hover lift) styles the whole panel; confirm against the installed Gradio.
custom_css = """
#main-container {
max-width: 1200px;
margin: 0 auto;
padding: 20px;
background: #f5f5f5;
border-radius: 15px;
box-shadow: 0 4px 6px rgba(0,0,0,0.1);
}
.header {
text-align: center;
padding: 2em;
background: linear-gradient(135deg, #2b5876 0%, #4e4376 100%);
color: white;
border-radius: 15px;
margin-bottom: 2em;
}
.tab-nav {
background: none !important;
border: none !important;
}
.tab-button {
padding: 1em 2em !important;
border-radius: 8px !important;
margin: 0 5px !important;
transition: all 0.3s ease !important;
}
.tab-button:hover {
transform: translateY(-2px);
box-shadow: 0 3px 6px rgba(0,0,0,0.1);
}
.dark-btn {
background: linear-gradient(135deg, #434343 0%, #000000 100%) !important;
color: white !important;
border: none !important;
padding: 12px 24px !important;
border-radius: 8px !important;
}
.output-card {
background: white !important;
border-radius: 10px !important;
padding: 20px !important;
box-shadow: 0 2px 4px rgba(0,0,0,0.05) !important;
}
.progress-indicator {
color: #666;
font-style: italic;
margin-top: 10px;
}
"""
# Build the four-tab production UI (script -> voice -> music -> final mix)
# and wire each button to its processing function.
with gr.Blocks(css=custom_css, theme=gr.themes.Default()) as demo:
    with gr.Column(elem_id="main-container"):
        # Header Section
        with gr.Column(elem_classes="header"):
            gr.Markdown("""
# 🎙️ AI Promo Studio
**Professional Audio Production Suite Powered by AI**
""")
        # Main Workflow Tabs
        with gr.Tabs(elem_classes="tab-nav"):
            # Script Generation
            with gr.Tab("📝 Script Design", elem_classes="tab-button"):
                with gr.Row(equal_height=False):
                    with gr.Column(scale=2):
                        gr.Markdown("### 🎯 Project Brief")
                        user_prompt = gr.Textbox(
                            label="Describe your promo concept",
                            placeholder="e.g., 'An intense 30-second movie trailer intro with epic orchestral music and dramatic sound effects...'",
                            lines=4,
                        )
                        with gr.Row():
                            duration = gr.Slider(
                                label="Duration (seconds)",
                                minimum=15,
                                maximum=120,
                                step=15,
                                value=30,
                                interactive=True,
                            )
                            llama_model_id = gr.Dropdown(
                                label="AI Model",
                                choices=["meta-llama/Meta-Llama-3-8B-Instruct"],
                                value="meta-llama/Meta-Llama-3-8B-Instruct",
                                interactive=True,
                            )
                        generate_btn = gr.Button("Generate Script 🚀", elem_classes="dark-btn")
                    with gr.Column(scale=1, elem_classes="output-card"):
                        gr.Markdown("### 📋 Generated Content")
                        script_output = gr.Textbox(label="Voice Script", lines=6)
                        sound_design_output = gr.Textbox(label="Sound Design", lines=3)
                        music_suggestion_output = gr.Textbox(label="Music Style", lines=3)

            # Voice Production
            with gr.Tab("🎙️ Voice Production", elem_classes="tab-button"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### 🔊 Voice Settings")
                        tts_model = gr.Dropdown(
                            label="Voice Model",
                            choices=[
                                "tts_models/en/ljspeech/tacotron2-DDC",
                                "tts_models/en/ljspeech/vits",
                                "tts_models/en/sam/tacotron-DDC",
                            ],
                            value="tts_models/en/ljspeech/tacotron2-DDC",
                            interactive=True,
                        )
                        with gr.Row():
                            voice_preview_btn = gr.Button("Preview Sample", elem_classes="dark-btn")
                            voice_generate_btn = gr.Button("Generate Full Voiceover", elem_classes="dark-btn")
                    with gr.Column(scale=1, elem_classes="output-card"):
                        gr.Markdown("### 🎧 Voice Preview")
                        voice_audio = gr.Audio(
                            label="Generated Voice",
                            interactive=False,
                            # blend_audio reads this value with AudioSegment.from_wav,
                            # so it must be a file path, not the default
                            # (sample_rate, ndarray) tuple.
                            type="filepath",
                            waveform_options={"show_controls": True},
                        )

            # Music Production
            with gr.Tab("🎵 Music Design", elem_classes="tab-button"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### 🎹 Music Parameters")
                        audio_length = gr.Slider(
                            label="Generation Length",
                            minimum=256,
                            maximum=1024,
                            step=64,
                            value=512,
                            info="Higher values = longer generation time",
                        )
                        music_generate_btn = gr.Button("Generate Music Track", elem_classes="dark-btn")
                    with gr.Column(scale=1, elem_classes="output-card"):
                        gr.Markdown("### 🎶 Music Preview")
                        music_output = gr.Audio(
                            label="Generated Music",
                            interactive=False,
                            # Passed to blend_audio, which expects a WAV file path.
                            type="filepath",
                            waveform_options={"show_controls": True},
                        )

            # Final Mix
            with gr.Tab("🎼 Final Mix", elem_classes="tab-button"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### 🎛️ Mixing Console")
                        ducking_enabled = gr.Checkbox(
                            label="Enable Voice Ducking",
                            value=True,
                            info="Automatically lower music during voice segments",
                        )
                        duck_level = gr.Slider(
                            label="Ducking Intensity (dB)",
                            minimum=3,
                            maximum=20,
                            step=1,
                            value=10,
                        )
                        mix_btn = gr.Button("Generate Final Mix", elem_classes="dark-btn")
                    with gr.Column(scale=1, elem_classes="output-card"):
                        gr.Markdown("### 🎧 Final Production")
                        final_mix = gr.Audio(
                            label="Mixed Output",
                            interactive=False,
                            waveform_options={"show_controls": True},
                        )

        # Footer
        with gr.Column(elem_classes="output-card"):
            gr.Markdown("""
<div style="text-align: center; padding: 1.5em 0;">
<a href="https://bilsimaging.com" target="_blank">
<img src="https://bilsimaging.com/logo.png" alt="Bils Imaging" style="height: 35px; margin-right: 15px;">
</a>
<a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2Fradiogold">
<img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2Fradiogold&countColor=%23263759" />
</a>
</div>
<p style="text-align: center; color: #666; font-size: 0.9em;">
Professional Audio Production Suite v2.1 © 2024 | Bils Imaging
</p>
""")

    # Event Handling
    def _generate_script_with_server_token(prompt, model_id, seconds):
        """Call generate_script with HF_TOKEN supplied on the server side.

        The token is deliberately not routed through a (hidden) Gradio
        component. Component values are serialized into the page config sent
        to every visitor's browser, which would leak the secret.
        """
        return generate_script(prompt, model_id, HF_TOKEN, seconds)

    generate_btn.click(
        _generate_script_with_server_token,
        inputs=[user_prompt, llama_model_id, duration],
        outputs=[script_output, sound_design_output, music_suggestion_output],
    )
    # TODO: voice_preview_btn has no click handler yet, so "Preview Sample"
    # currently does nothing.
    voice_generate_btn.click(
        generate_voice,
        inputs=[script_output, tts_model],
        outputs=voice_audio,
    )
    music_generate_btn.click(
        generate_music,
        inputs=[music_suggestion_output, audio_length],
        outputs=music_output,
    )
    mix_btn.click(
        blend_audio,
        inputs=[voice_audio, music_output, ducking_enabled, duck_level],
        outputs=final_mix,
    )

if __name__ == "__main__":
    demo.launch(debug=True)