import os
import tempfile

import gradio as gr
import spaces
import torch
from dotenv import load_dotenv
from pydub import AudioSegment
from scipy.io.wavfile import write
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    MusicgenForConditionalGeneration,
    pipeline,
)
from TTS.api import TTS

# -----------------------------------------------------------
# Initialization & Environment Setup
# -----------------------------------------------------------
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")

# -----------------------------------------------------------
# Model Cache Management
# -----------------------------------------------------------
# Each cache maps a model identifier to the already-loaded object so that a
# model is only loaded once per process.
LLAMA_PIPELINES = {}
MUSICGEN_MODELS = {}
TTS_MODELS = {}


def get_llama_pipeline(model_id: str, token: str):
    """Return a cached text-generation pipeline for ``model_id``.

    On the first call for a given ``model_id`` the tokenizer and model are
    loaded (fp16, ``device_map="auto"``) and wrapped in a ``text-generation``
    pipeline; later calls return the cached pipeline.

    Args:
        model_id: Hugging Face model identifier.
        token: Hugging Face access token used to download the model.

    Returns:
        The ``transformers`` text-generation pipeline for ``model_id``.
    """
    cached = LLAMA_PIPELINES.get(model_id)
    if cached is not None:
        return cached

    # ``token`` replaces the deprecated ``use_auth_token`` keyword.
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    # NOTE(review): trust_remote_code=True allows the model repository to run
    # its own Python code at load time; only pass trusted model ids.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=token,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    text_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
    LLAMA_PIPELINES[model_id] = text_pipeline
    return text_pipeline


def get_musicgen_model(model_key: str = "facebook/musicgen-large"):
    """Return a cached ``(model, processor)`` pair for a MusicGen checkpoint.

    The model is moved to CUDA when available, otherwise it stays on CPU.

    Args:
        model_key: Hugging Face identifier of the MusicGen checkpoint.

    Returns:
        Tuple of ``(MusicgenForConditionalGeneration, processor)``.
    """
    cached = MUSICGEN_MODELS.get(model_key)
    if cached is not None:
        return cached

    model = MusicgenForConditionalGeneration.from_pretrained(model_key)
    processor = AutoProcessor.from_pretrained(model_key)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    MUSICGEN_MODELS[model_key] = (model, processor)
    return model, processor


def get_tts_model(model_name: str = "tts_models/en/ljspeech/tacotron2-DDC"):
    """Return a cached Coqui ``TTS`` instance for ``model_name``.

    Args:
        model_name: TTS model name passed straight to ``TTS(...)``.

    Returns:
        The loaded ``TTS`` object.
    """
    cached = TTS_MODELS.get(model_name)
    if cached is not None:
        return cached

    tts_model = TTS(model_name)
    TTS_MODELS[model_name] = tts_model
    return tts_model


# -----------------------------------------------------------
# Core Functionality
# -----------------------------------------------------------
# Section labels shared by the LLM prompt and the parser so the two cannot
# drift apart.
_VOICE_HEADER = "Voice-Over Script:"
_SOUND_HEADER = "Sound Design Suggestions:"
_MUSIC_HEADER = "Music Suggestions:"
_SECTION_HEADERS = (_VOICE_HEADER, _SOUND_HEADER, _MUSIC_HEADER)


def _parse_script_sections(generated_text: str) -> dict:
    """Split LLM output into the three labelled sections.

    A line containing one of ``_SECTION_HEADERS`` switches the current
    section (the label itself is removed from that line); every line seen
    while a section is active is appended to it. Text before the first
    label is discarded.

    Returns:
        Dict mapping each header to its stripped text ("" when absent).
    """
    collected = {header: [] for header in _SECTION_HEADERS}
    current = None
    for line in generated_text.split("\n"):
        for header in _SECTION_HEADERS:
            if header in line:
                current = header
                line = line.replace(header, "").strip()
        if current:
            collected[current].append(line)
    return {header: "\n".join(lines).strip() for header, lines in collected.items()}


@spaces.GPU(duration=100)
def generate_script(user_prompt: str, model_id: str, token: str, duration: int):
    """Generate a voice script, sound-design notes and music notes.

    Args:
        user_prompt: Free-text client brief.
        model_id: Hugging Face id of the text-generation model.
        token: Hugging Face access token used to load the model.
        duration: Target length of the promo in seconds (used in the prompt).

    Returns:
        Tuple ``(voice_script, sound_design, music_suggestions)``. Missing
        sections are replaced by a "No ..." placeholder. On any exception the
        tuple is ``("Error: <message>", "", "")``.
    """
    try:
        text_pipeline = get_llama_pipeline(model_id, token)

        # Ask for the exact labels the parser looks for; the previous wording
        # used different section names, so parsing found nothing.
        system_prompt = (
            f"You are a professional audio producer creating {duration}-second content.\n"
            "Generate three sections, each introduced by its exact label:\n"
            f"{_VOICE_HEADER} (clear and concise)\n"
            f"{_SOUND_HEADER} (specific effects)\n"
            f"{_MUSIC_HEADER} (genre, tempo)"
        )
        full_prompt = f"{system_prompt}\nClient brief: {user_prompt}\nOutput:"

        with torch.inference_mode():
            result = text_pipeline(
                full_prompt,
                max_new_tokens=400,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )

        raw_text = result[0]["generated_text"]
        if raw_text.startswith(full_prompt):
            # Drop the echoed prompt without being fooled by a later
            # "Output:" inside the completion itself.
            generated_text = raw_text[len(full_prompt):].strip()
        else:
            generated_text = raw_text.split("Output:")[-1].strip()

        sections = _parse_script_sections(generated_text)
        return (
            sections[_VOICE_HEADER] or "No script generated",
            sections[_SOUND_HEADER] or "No sound design suggestions",
            sections[_MUSIC_HEADER] or "No music suggestions",
        )
    except Exception as e:
        # UI boundary: surface the failure in the script textbox.
        return f"Error: {str(e)}", "", ""


def _generate_script_from_ui(user_prompt: str, model_id: str, duration: int):
    """Call ``generate_script`` with the server-side ``HF_TOKEN``.

    Keeps the access token out of the Gradio component tree so it is never
    sent to the browser.
    """
    return generate_script(user_prompt, model_id, HF_TOKEN, duration)


@spaces.GPU(duration=100)
def generate_voice(script: str, tts_model_name: str):
    """Synthesize ``script`` to a WAV file with the selected TTS model.

    Args:
        script: Text to speak.
        tts_model_name: TTS model name passed to ``get_tts_model``.

    Returns:
        Path of the written WAV file, or ``None`` when the script is blank or
        synthesis fails (the error is printed).
    """
    try:
        if not script.strip():
            return None

        tts_model = get_tts_model(tts_model_name)
        output_path = f"{tempfile.gettempdir()}/voice_temp.wav"
        tts_model.tts_to_file(text=script, file_path=output_path)
        return output_path
    except Exception as e:
        print(f"Voice generation error: {e}")
        return None


@spaces.GPU(duration=100)
def generate_music(prompt: str, audio_length: int):
    """Generate a music clip from a text prompt with MusicGen.

    Args:
        prompt: Text description of the music.
        audio_length: Passed to ``model.generate`` as ``max_new_tokens``.

    Returns:
        Path of the written 16-bit WAV file, or ``None`` on failure (the
        error is printed).
    """
    try:
        model, processor = get_musicgen_model()
        device = "cuda" if torch.cuda.is_available() else "cpu"

        inputs = processor(text=[prompt], padding=True, return_tensors="pt").to(device)
        with torch.inference_mode():
            outputs = model.generate(**inputs, max_new_tokens=int(audio_length))

        audio_data = outputs[0, 0].cpu().numpy()

        # Peak-normalise to the int16 range; skip the division for an empty
        # or silent clip to avoid dividing by zero.
        peak = float(abs(audio_data).max()) if audio_data.size else 0.0
        if peak > 0:
            audio_data = audio_data / peak
        normalized_audio = (audio_data * 32767).astype("int16")

        # Write at the model's own sampling rate; a hard-coded 44100 Hz makes
        # the clip play back at the wrong speed and pitch.
        sampling_rate = model.config.audio_encoder.sampling_rate

        output_path = f"{tempfile.gettempdir()}/music_temp.wav"
        write(output_path, sampling_rate, normalized_audio)
        return output_path
    except Exception as e:
        print(f"Music generation error: {e}")
        return None


@spaces.GPU(duration=100)
def blend_audio(voice_path: str, music_path: str, ducking: bool, duck_level: int):
    """Overlay the voice track on the music track and export the mix.

    The music is looped and trimmed so it is exactly as long as the voice.
    When ``ducking`` is true the whole music track is attenuated by
    ``duck_level`` dB before the overlay.

    Args:
        voice_path: Path to the voice WAV file.
        music_path: Path to the music WAV file.
        ducking: Whether to lower the music level.
        duck_level: Attenuation in dB applied to the music when ducking.

    Returns:
        Path of the exported WAV mix, or ``None`` when an input is missing or
        mixing fails (the error is printed).
    """
    try:
        if not voice_path or not music_path:
            print("Mixing error: both a voice track and a music track are required")
            return None

        voice = AudioSegment.from_wav(voice_path)
        music = AudioSegment.from_wav(music_path)

        if len(music) == 0:
            print("Mixing error: music track is empty")
            return None

        # Adjust music length: loop until it covers the voice, then trim.
        if len(music) < len(voice):
            loops_needed = (len(voice) // len(music)) + 1
            music = music * loops_needed
        music = music[:len(voice)]

        # Ducking effect
        bed = music - duck_level if ducking else music
        final_audio = bed.overlay(voice)

        output_path = f"{tempfile.gettempdir()}/final_mix.wav"
        final_audio.export(output_path, format="wav")
        return output_path
    except Exception as e:
        print(f"Mixing error: {e}")
        return None


# -----------------------------------------------------------
# Enhanced UI Components
# -----------------------------------------------------------
custom_css = """
#main-container {
    max-width: 1200px;
    margin: 0 auto;
    padding: 20px;
    background: #f5f5f5;
    border-radius: 15px;
    box-shadow: 0 4px 6px rgba(0,0,0,0.1);
}
.header {
    text-align: center;
    padding: 2em;
    background: linear-gradient(135deg, #2b5876 0%, #4e4376 100%);
    color: white;
    border-radius: 15px;
    margin-bottom: 2em;
}
.tab-nav {
    background: none !important;
    border: none !important;
}
.tab-button {
    padding: 1em 2em !important;
    border-radius: 8px !important;
    margin: 0 5px !important;
    transition: all 0.3s ease !important;
}
.tab-button:hover {
    transform: translateY(-2px);
    box-shadow: 0 3px 6px rgba(0,0,0,0.1);
}
.dark-btn {
    background: linear-gradient(135deg, #434343 0%, #000000 100%) !important;
    color: white !important;
    border: none !important;
    padding: 12px 24px !important;
    border-radius: 8px !important;
}
.output-card {
    background: white !important;
    border-radius: 10px !important;
    padding: 20px !important;
    box-shadow: 0 2px 4px rgba(0,0,0,0.05) !important;
}
.progress-indicator {
    color: #666;
    font-style: italic;
    margin-top: 10px;
}
"""

with gr.Blocks(css=custom_css, theme=gr.themes.Default()) as demo:
    with gr.Column(elem_id="main-container"):
        # Header Section
        with gr.Column(elem_classes="header"):
            gr.Markdown("""
            # 🎙️ AI Promo Studio
            **Professional Audio Production Suite Powered by AI**
            """)

        # Main Workflow Tabs
        with gr.Tabs(elem_classes="tab-nav"):
            # Script Generation
            with gr.Tab("📝 Script Design", elem_classes="tab-button"):
                with gr.Row(equal_height=False):
                    with gr.Column(scale=2):
                        gr.Markdown("### 🎯 Project Brief")
                        user_prompt = gr.Textbox(
                            label="Describe your promo concept",
                            placeholder="e.g., 'An intense 30-second movie trailer intro with epic orchestral music and dramatic sound effects...'",
                            lines=4,
                        )
                        with gr.Row():
                            duration = gr.Slider(
                                label="Duration (seconds)",
                                minimum=15,
                                maximum=120,
                                step=15,
                                value=30,
                                interactive=True,
                            )
                            llama_model_id = gr.Dropdown(
                                label="AI Model",
                                choices=["meta-llama/Meta-Llama-3-8B-Instruct"],
                                value="meta-llama/Meta-Llama-3-8B-Instruct",
                                interactive=True,
                            )
                        generate_btn = gr.Button("Generate Script 🚀", elem_classes="dark-btn")

                    with gr.Column(scale=1, elem_classes="output-card"):
                        gr.Markdown("### 📄 Generated Content")
                        script_output = gr.Textbox(label="Voice Script", lines=6)
                        sound_design_output = gr.Textbox(label="Sound Design", lines=3)
                        music_suggestion_output = gr.Textbox(label="Music Style", lines=3)

            # Voice Production
            with gr.Tab("🎙️ Voice Production", elem_classes="tab-button"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### 🔊 Voice Settings")
                        tts_model = gr.Dropdown(
                            label="Voice Model",
                            choices=[
                                "tts_models/en/ljspeech/tacotron2-DDC",
                                "tts_models/en/ljspeech/vits",
                                "tts_models/en/sam/tacotron-DDC",
                            ],
                            value="tts_models/en/ljspeech/tacotron2-DDC",
                            interactive=True,
                        )
                        with gr.Row():
                            # No click handler is attached to this button in
                            # this file.
                            voice_preview_btn = gr.Button("Preview Sample", elem_classes="dark-btn")
                            voice_generate_btn = gr.Button("Generate Full Voiceover", elem_classes="dark-btn")

                    with gr.Column(scale=1, elem_classes="output-card"):
                        gr.Markdown("### 🎧 Voice Preview")
                        # type="filepath" so blend_audio receives a path
                        # rather than a (sample_rate, ndarray) tuple.
                        voice_audio = gr.Audio(
                            label="Generated Voice",
                            type="filepath",
                            interactive=False,
                            waveform_options={"show_controls": True},
                        )

            # Music Production
            with gr.Tab("🎵 Music Design", elem_classes="tab-button"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### 🎹 Music Parameters")
                        audio_length = gr.Slider(
                            label="Generation Length",
                            minimum=256,
                            maximum=1024,
                            step=64,
                            value=512,
                            info="Higher values = longer generation time",
                        )
                        music_generate_btn = gr.Button("Generate Music Track", elem_classes="dark-btn")

                    with gr.Column(scale=1, elem_classes="output-card"):
                        gr.Markdown("### 🎶 Music Preview")
                        # type="filepath" for the same reason as voice_audio.
                        music_output = gr.Audio(
                            label="Generated Music",
                            type="filepath",
                            interactive=False,
                            waveform_options={"show_controls": True},
                        )

            # Final Mix
            with gr.Tab("🔊 Final Mix", elem_classes="tab-button"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### 🎚️ Mixing Console")
                        ducking_enabled = gr.Checkbox(
                            label="Enable Voice Ducking",
                            value=True,
                            info="Automatically lower music during voice segments",
                        )
                        duck_level = gr.Slider(
                            label="Ducking Intensity (dB)",
                            minimum=3,
                            maximum=20,
                            step=1,
                            value=10,
                        )
                        mix_btn = gr.Button("Generate Final Mix", elem_classes="dark-btn")

                    with gr.Column(scale=1, elem_classes="output-card"):
                        gr.Markdown("### 🎧 Final Production")
                        final_mix = gr.Audio(
                            label="Mixed Output",
                            interactive=False,
                            waveform_options={"show_controls": True},
                        )

        # Footer
        with gr.Column(elem_classes="output-card"):
            gr.Markdown("""
Bils Imaging

Professional Audio Production Suite v2.1 © 2024 | Bils Imaging

""")

    # Event Handling
    # The HF token is injected server-side by the wrapper instead of being
    # placed in a hidden Textbox, which would expose it to the browser.
    generate_btn.click(
        _generate_script_from_ui,
        inputs=[user_prompt, llama_model_id, duration],
        outputs=[script_output, sound_design_output, music_suggestion_output],
    )

    voice_generate_btn.click(
        generate_voice,
        inputs=[script_output, tts_model],
        outputs=voice_audio,
    )

    music_generate_btn.click(
        generate_music,
        inputs=[music_suggestion_output, audio_length],
        outputs=music_output,
    )

    mix_btn.click(
        blend_audio,
        inputs=[voice_audio, music_output, ducking_enabled, duck_level],
        outputs=final_mix,
    )

if __name__ == "__main__":
    demo.launch(debug=True)