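# NOTE: the scraped page header read "Spaces: Sleeping" (Hugging Face Space
# status badge); it is page chrome, not part of the application code.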
import io
import re

import scipy.io.wavfile
import streamlit as st
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    MusicgenForConditionalGeneration,
    pipeline,
)
# ---------------------------------------------------------------------
# Page Configuration
# ---------------------------------------------------------------------
# Browser-tab title/icon and overall layout, gathered in one mapping so the
# settings can be read at a glance before being handed to Streamlit.
PAGE_SETTINGS = {
    "page_title": "Radio Imaging Audio Generator - Llama & MusicGen",
    "page_icon": "🎧",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**PAGE_SETTINGS)
# ---------------------------------------------------------------------
# Custom CSS for a Vibrant UI
# ---------------------------------------------------------------------
# Stylesheet injected into the page as raw HTML. It sets the base colours and
# font, restyles Streamlit buttons, tints the sidebar, and defines two helper
# classes: `.material-card` (not referenced in the visible part of this
# script) and `.footer-note` (used by the footer markup at the bottom).
CUSTOM_CSS = """
<style>
body {
background-color: #F8FBFE;
color: #1F2937;
font-family: 'Segoe UI', Tahoma, sans-serif;
}
h1, h2, h3, h4, h5, h6 {
color: #3B82F6;
}
.stButton>button {
background-color: #3B82F6 !important;
color: #FFFFFF !important;
border-radius: 8px !important;
font-size: 16px !important;
}
.sidebar .sidebar-content {
background: #E0F2FE;
}
.material-card {
border: 1px solid #D1D5DB;
border-radius: 8px;
padding: 1rem;
margin-bottom: 1rem;
background-color: #ffffff;
}
.footer-note {
text-align: center;
opacity: 0.6;
font-size: 14px;
margin-top: 30px;
}
</style>
"""
# unsafe_allow_html=True lets the <style> element through as HTML instead of
# being rendered as literal text.
st.markdown(CUSTOM_CSS, unsafe_allow_html=True)
# ---------------------------------------------------------------------
# Header Section
# ---------------------------------------------------------------------
# Page title (with an inline "Beta" badge) and a one-line tagline, followed
# by a horizontal rule separating the header from the rest of the page.
HEADER_HTML = """
<h1>Radio Imaging Audio Generator <span style="font-size: 24px; color: #F59E0B;">(Beta)</span></h1>
<p style='font-size:18px;'>
Generate custom radio imaging audio, ads, and promo tracks with Llama & MusicGen!
</p>
"""
st.markdown(HEADER_HTML, unsafe_allow_html=True)
st.markdown("---")
# ---------------------------------------------------------------------
# Instructions Section in an Expander
# ---------------------------------------------------------------------
# Collapsible usage guide. The blank lines around "**Tips**:" matter: without
# them Markdown treats that line as a continuation of list item 4 instead of
# a separate heading-like paragraph above the bullet list.
with st.expander("📘 How to Use This Web App"):
    st.markdown(
        """
1. **Enter your prompt**: Describe the type of audio you need (e.g., an energetic 15-second jingle for a pop radio promo).
2. **Generate Description**: Let Llama 2 (or another open-source model) refine your prompt into a creative script.
3. **Generate Audio**: Pass that script to MusicGen to get a custom audio file.
4. **Playback & Download**: Listen to your new track and download it for further editing.

**Tips**:

- Keep descriptions short & specific for best results.
- If the Llama model is too large, switch to a smaller open-source model or try a GPU-based environment.
- If you see errors about model permissions, ensure you’ve accepted the license on Hugging Face.
"""
    )
# ---------------------------------------------------------------------
# Sidebar: Model Selection & Options
# ---------------------------------------------------------------------
# Help strings are named up front so the widget calls below stay short.
MODEL_ID_HELP = "For example: meta-llama/Llama-2-7b-chat-hf (requires license acceptance)."
DEVICE_HELP = "If running locally with a GPU, choose 'auto'. If you only have a CPU, pick 'cpu'."

with st.sidebar:
    st.header("🔧 Model Config")

    # Hugging Face id of the text model used to refine the user's brief.
    llama_model_id = st.text_input(
        label="Llama 2 Model ID on Hugging Face",
        value="meta-llama/Llama-2-7b-chat-hf",
        help=MODEL_ID_HELP,
    )

    # Passed through to the model loader as its device setting.
    device_option = st.selectbox(
        label="Hardware Device",
        options=["auto", "cpu"],
        help=DEVICE_HELP,
    )
# ---------------------------------------------------------------------
# Prompt Input
# ---------------------------------------------------------------------
st.markdown("## ✍🏻 Write Your Brief / Concept")

# Free-text brief from the user; it is later handed to the Llama refinement
# step when the "Refine Description" button is pressed.
PROMPT_LABEL = (
    "Describe the radio imaging or jingle you want to create. "
    "Include style, mood, duration, etc."
)
PROMPT_PLACEHOLDER = (
    "e.g. 'An energetic 15-second pop jingle for a morning radio show, upbeat and fun...'"
)
prompt = st.text_area(label=PROMPT_LABEL, placeholder=PROMPT_PLACEHOLDER)
# ---------------------------------------------------------------------
# Text Generation with Llama
# ---------------------------------------------------------------------
# Streamlit re-executes this script on every interaction, so without caching
# the model would be reloaded on each button click. max_entries=1 lets a new
# model id / device replace the previous model instead of keeping both alive.
@st.cache_resource(show_spinner=False, max_entries=1)
def load_llama_pipeline(model_id: str, device: str):
    """
    Load the Llama or other open-source model as a text-generation pipeline.

    The user must have accepted the license for certain models like Llama 2.

    Args:
        model_id: Hugging Face model id passed to ``from_pretrained``.
        device: Value forwarded as ``device_map`` when loading the model
            (this app offers "auto" and "cpu").

    Returns:
        A ``transformers`` text-generation pipeline wrapping the loaded model
        and tokenizer.
    """
    # Half precision only pays off on an accelerator; with "auto" on a
    # CPU-only host the weights must stay float32, since float16 inference
    # on CPU is slow and not supported by every operator.
    mps_backend = getattr(torch.backends, "mps", None)
    accelerator_available = torch.cuda.is_available() or (
        mps_backend is not None and mps_backend.is_available()
    )
    use_half_precision = device == "auto" and accelerator_available

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if use_half_precision else torch.float32,
        device_map=device,
    )

    # The model is already placed by from_pretrained above, so the pipeline
    # does not need its own device_map.
    return pipeline("text-generation", model=model, tokenizer=tokenizer)
def generate_description(user_prompt: str, pipeline_gen):
    """
    Use the pipeline to create a refined description for MusicGen.

    Args:
        user_prompt: The user's short concept / brief.
        pipeline_gen: Callable invoked as
            ``pipeline_gen(prompt, max_new_tokens=..., do_sample=..., temperature=...)``
            that returns a list whose first item has a ``"generated_text"`` key.

    Returns:
        The refined script with surrounding whitespace removed, followed by a
        blank line and the app's credit line.
    """
    system_prompt = (
        "You are a helpful assistant specialized in creative advertising scripts and radio imaging. "
        "Refine the user's short concept into a more detailed, creative script. "
        "Keep it concise, but highlight any relevant tone, instruments, or style to guide music generation."
    )
    # Single combined prompt; the trailing "Your refined script:" doubles as
    # the marker used below when the output has to be cleaned heuristically.
    combined_prompt = f"{system_prompt}\nUser request: {user_prompt}\nYour refined script:"

    result = pipeline_gen(
        combined_prompt,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
    )
    generated_text = result[0]["generated_text"]

    if generated_text.startswith(combined_prompt):
        # The prompt was echoed back verbatim: drop exactly that prefix so a
        # "script:" inside the model's own answer cannot truncate the result.
        refined = generated_text[len(combined_prompt):]
    else:
        # Heuristic fallback: keep what follows the last "script:" marker,
        # matched case-insensitively; text without the marker is kept whole.
        refined = re.split(r"script:", generated_text, flags=re.IGNORECASE)[-1]

    # Sign-off / credit line appended to every refined description.
    return refined.strip() + "\n\n(Generated by Radio Imaging Audio Generator - Llama Edition)"
# Button: Generate Description
# An empty brief is rejected up front; otherwise the brief is refined, the
# result is stored in session state (the audio step reads it from there),
# shown on the page, and offered as a text download.
refine_clicked = st.button("📄 Refine Description with Llama")
if refine_clicked and not prompt.strip():
    st.error("Please provide a brief concept before generating a description.")
elif refine_clicked:
    with st.spinner("Generating a refined description..."):
        try:
            pipeline_llama = load_llama_pipeline(llama_model_id, device_option)
            refined_text = generate_description(prompt, pipeline_llama)
            st.session_state["refined_prompt"] = refined_text
            st.success("Description successfully refined!")
            st.write(refined_text)
            st.download_button(
                "📥 Download Description",
                refined_text,
                file_name="refined_description.txt",
            )
        except Exception as err:
            # Top-level UI boundary: report any failure to the user instead
            # of letting the traceback take over the page.
            st.error(f"Error while generating with Llama: {err}")

st.markdown("---")
# ---------------------------------------------------------------------
# MusicGen: Generate Audio
# ---------------------------------------------------------------------
# Cached so the model and processor are loaded once and reused across
# Streamlit reruns rather than on every button click.
@st.cache_resource(show_spinner=False)
def load_musicgen_model(model_name: str = "facebook/musicgen-small"):
    """
    Load and cache the MusicGen model and processor.

    Args:
        model_name: Hugging Face id of the MusicGen checkpoint; defaults to
            the small checkpoint this app has always used.

    Returns:
        A ``(model, processor)`` tuple.
    """
    mg_model = MusicgenForConditionalGeneration.from_pretrained(model_name)
    mg_processor = AutoProcessor.from_pretrained(model_name)
    return mg_model, mg_processor
# Button: Generate Audio
# Reads the refined description stored by the refinement step, runs MusicGen
# on it and plays the result.
# NOTE(review): the stored description still ends with the credit line that
# generate_description appends, so that text is part of the MusicGen prompt.
if st.button("▶ Generate Audio with MusicGen"):
    descriptive_text = st.session_state.get("refined_prompt")
    if not descriptive_text:
        st.error("Please generate or have a refined description first.")
    else:
        with st.spinner("Generating your audio... This can take a moment."):
            try:
                musicgen_model, processor = load_musicgen_model()

                # Use the refined prompt as input
                inputs = processor(
                    text=[descriptive_text],
                    padding=True,
                    return_tensors="pt",
                )
                audio_values = musicgen_model.generate(**inputs, max_new_tokens=512)
                sampling_rate = musicgen_model.config.audio_encoder.sampling_rate

                # Encode the WAV in memory: a fixed filename in the working
                # directory would be shared (and overwritten) by every
                # concurrent session of the app.
                wav_buffer = io.BytesIO()
                scipy.io.wavfile.write(
                    wav_buffer,
                    rate=sampling_rate,
                    # .cpu() keeps this working if the model runs on a GPU.
                    data=audio_values[0, 0].cpu().numpy(),
                )

                st.success("Audio successfully generated!")
                st.audio(wav_buffer.getvalue(), format="audio/wav")
            except Exception as e:
                # Top-level UI boundary: surface the failure to the user.
                st.error(f"Error while generating audio: {e}")
# ---------------------------------------------------------------------
# Footer Section
# ---------------------------------------------------------------------
# Credit line rendered with the `.footer-note` class, preceded by a rule.
FOOTER_HTML = "".join(
    [
        "<div class='footer-note'>",
        "✅ Built with Llama 2 & MusicGen · ",
        "Created for radio imaging producers · ",
        "Feedback welcome at <a href='https://bilsimaging.com' target='_blank'>Bilsimaging</a>!",
        "</div>",
    ]
)
st.markdown("---")
st.markdown(FOOTER_HTML, unsafe_allow_html=True)

# Hide Streamlit's default menu and footer if you wish
HIDE_DEFAULT_CHROME_CSS = "<style>#MainMenu {visibility: hidden;} footer {visibility: hidden;}</style>"
st.markdown(HIDE_DEFAULT_CHROME_CSS, unsafe_allow_html=True)