AIPromoStudio / app.py
Bils's picture
Update app.py
db46bfb verified
raw
history blame
9.64 kB
import io

import scipy.io.wavfile
import streamlit as st
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    MusicgenForConditionalGeneration,
    pipeline,
)
# ---------------------------------------------------------------------
# Page Configuration: browser tab title/icon, wide layout, open sidebar
# ---------------------------------------------------------------------
st.set_page_config(
    page_title="Radio Imaging Audio Generator - Llama & MusicGen",
    page_icon="🎧",
    layout="wide",
    initial_sidebar_state="expanded",
)
# ---------------------------------------------------------------------
# Custom CSS for a Vibrant UI
# ---------------------------------------------------------------------
# Injected as raw HTML. The selectors target Streamlit's generated DOM,
# which is not a public API, so they may need revisiting after upgrades.
CUSTOM_CSS = """
<style>
body {
    background-color: #F8FBFE;
    color: #1F2937;
    font-family: 'Segoe UI', Tahoma, sans-serif;
}
h1, h2, h3, h4, h5, h6 {
    color: #3B82F6;
}
.stButton>button {
    background-color: #3B82F6 !important;
    color: #FFFFFF !important;
    border-radius: 8px !important;
    font-size: 16px !important;
}
/* ".sidebar .sidebar-content" only matches legacy Streamlit markup; newer
   releases mark the sidebar with data-testid="stSidebar". Both selectors
   are kept so the tint applies on either. */
.sidebar .sidebar-content,
[data-testid="stSidebar"] {
    background: #E0F2FE;
}
.material-card {
    border: 1px solid #D1D5DB;
    border-radius: 8px;
    padding: 1rem;
    margin-bottom: 1rem;
    background-color: #ffffff;
}
.footer-note {
    text-align: center;
    opacity: 0.6;
    font-size: 14px;
    margin-top: 30px;
}
</style>
"""
st.markdown(CUSTOM_CSS, unsafe_allow_html=True)
# ---------------------------------------------------------------------
# Header Section: title with a "Beta" tag, a one-line pitch, then a rule
# ---------------------------------------------------------------------
st.markdown(
    body="""
<h1>Radio Imaging Audio Generator <span style="font-size: 24px; color: #F59E0B;">(Beta)</span></h1>
<p style='font-size:18px;'>
Generate custom radio imaging audio, ads, and promo tracks with Llama & MusicGen!
</p>
""",
    unsafe_allow_html=True,
)
st.markdown("---")
# ---------------------------------------------------------------------
# Instructions Section in an Expander
# ---------------------------------------------------------------------
with st.expander("📘 How to Use This Web App"):
    # The blank lines around "**Tips**:" are significant in Markdown. Without
    # them the line is folded into step 4's paragraph instead of starting
    # its own section. Step labels mirror the button captions below.
    st.markdown(
        """
1. **Enter your prompt**: Describe the type of audio you need (e.g., an energetic 15-second jingle for a pop radio promo).
2. **Refine Description**: Let Llama 2 (or another open-source model) refine your prompt into a creative script.
3. **Generate Audio**: Pass that script to MusicGen to get a custom audio file.
4. **Playback & Download**: Listen to your new track and download it for further editing.

**Tips**:

- Keep descriptions short & specific for best results.
- If the Llama model is too large, switch to a smaller open-source model or try a GPU-based environment.
- If you see errors about model permissions, ensure you’ve accepted the license on Hugging Face.
"""
    )
# ---------------------------------------------------------------------
# Sidebar: Model Selection & Options
# ---------------------------------------------------------------------
st.sidebar.header("🔧 Model Config")
# Hugging Face repo id of the chat model used to refine the brief.
llama_model_id = st.sidebar.text_input(
    label="Llama 2 Model ID on Hugging Face",
    value="meta-llama/Llama-2-7b-chat-hf",
    help="For example: meta-llama/Llama-2-7b-chat-hf (requires license acceptance).",
)
# Passed straight through as the `device` argument of load_llama_pipeline.
device_option = st.sidebar.selectbox(
    label="Hardware Device",
    options=["auto", "cpu"],
    help="If running locally with a GPU, choose 'auto'. If you only have a CPU, pick 'cpu'.",
)
# ---------------------------------------------------------------------
# Prompt Input: the user's free-text brief, read by the refine step
# ---------------------------------------------------------------------
st.markdown("## ✍🏻 Write Your Brief / Concept")
prompt = st.text_area(
    label="Describe the radio imaging or jingle you want to create. Include style, mood, duration, etc.",
    placeholder="e.g. 'An energetic 15-second pop jingle for a morning radio show, upbeat and fun...'",
)
# ---------------------------------------------------------------------
# Text Generation with Llama
# ---------------------------------------------------------------------
@st.cache_resource
def load_llama_pipeline(model_id: str, device: str):
    """
    Build (and cache via st.cache_resource) a text-generation pipeline.

    Args:
        model_id: Hugging Face repo id of a causal LM. Gated models such as
            Llama 2 need their license accepted on Hugging Face first.
        device: Forwarded as ``device_map``. ``"auto"`` selects float16
            weights; any other value selects float32.

    Returns:
        A ``transformers`` "text-generation" pipeline wrapping the model
        and its tokenizer.
    """
    # NOTE(review): "auto" requests float16 even on a host without a GPU;
    # confirm that is acceptable for CPU-only deployments.
    weight_dtype = torch.float32
    if device == "auto":
        weight_dtype = torch.float16

    tok = AutoTokenizer.from_pretrained(model_id)
    causal_lm = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=weight_dtype,
        device_map=device,
    )
    return pipeline(
        "text-generation",
        model=causal_lm,
        tokenizer=tok,
        device_map=device,
    )
def generate_description(user_prompt: str, pipeline_gen):
    """
    Expand a short user brief into a more detailed script to steer MusicGen.

    Args:
        user_prompt: The user's free-text concept.
        pipeline_gen: Callable with the interface of a ``transformers``
            text-generation pipeline. It is called as
            ``pipeline_gen(prompt, **generation_kwargs)`` and must return a
            list whose first item is a dict with a ``"generated_text"`` key.

    Returns:
        The model's continuation with surrounding whitespace (and any echoed
        prompt) removed, followed by a blank line and an attribution line.
    """
    # Plain instruction-style prompt. It does not use the Llama 2 [INST]
    # chat template.
    system_prompt = (
        "You are a helpful assistant specialized in creative advertising scripts and radio imaging. "
        "Refine the user's short concept into a more detailed, creative script. "
        "Keep it concise, but highlight any relevant tone, instruments, or style to guide music generation."
    )
    combined_prompt = f"{system_prompt}\nUser request: {user_prompt}\nYour refined script:"
    result = pipeline_gen(
        combined_prompt,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        # Ask for the continuation only. Cutting the echoed prompt off at a
        # text marker is unreliable: the model's own output may contain the
        # same marker, and that text would be silently dropped.
        return_full_text=False,
    )
    generated_text = result[0]["generated_text"]
    # Defensive: a callable that ignores return_full_text may still echo
    # the prompt verbatim, so remove it only when it is an exact prefix.
    if generated_text.startswith(combined_prompt):
        generated_text = generated_text[len(combined_prompt):]
    generated_text = generated_text.strip()
    # Attribution line shown to and downloaded by the user.
    generated_text += "\n\n(Generated by Radio Imaging Audio Generator - Llama Edition)"
    return generated_text
# Button: refine the brief with the configured LLM and keep the result in
# session state for the MusicGen step.
if st.button("📄 Refine Description with Llama"):
    if prompt.strip():
        with st.spinner("Generating a refined description..."):
            try:
                pipeline_llama = load_llama_pipeline(llama_model_id, device_option)
                refined_text = generate_description(prompt, pipeline_llama)
                # Session state survives reruns, so the audio button can read it.
                st.session_state["refined_prompt"] = refined_text
                st.success("Description successfully refined!")
                st.write(refined_text)
                st.download_button(
                    label="📥 Download Description",
                    data=refined_text,
                    file_name="refined_description.txt",
                )
            except Exception as err:
                # UI boundary: surface any load/generation failure to the user.
                st.error(f"Error while generating with Llama: {err}")
    else:
        st.error("Please provide a brief concept before generating a description.")
st.markdown("---")
# ---------------------------------------------------------------------
# MusicGen: Generate Audio
# ---------------------------------------------------------------------
@st.cache_resource
def load_musicgen_model(model_id: str = "facebook/musicgen-small"):
    """
    Load and cache a MusicGen model together with its processor.

    Args:
        model_id: Hugging Face repo id of the MusicGen checkpoint. It
            defaults to the small checkpoint this app was built around.
            Model and processor always come from the same id so they stay
            in sync.

    Returns:
        A ``(model, processor)`` tuple, cached by st.cache_resource per
        ``model_id``.
    """
    mg_model = MusicgenForConditionalGeneration.from_pretrained(model_id)
    mg_processor = AutoProcessor.from_pretrained(model_id)
    return mg_model, mg_processor
# Button: render the refined description to audio with MusicGen.
if st.button("▶ Generate Audio with MusicGen"):
    # The refine step appends this attribution line for display and
    # download. It is not a musical description, so it must not be fed to
    # MusicGen as conditioning text.
    credit_suffix = "\n\n(Generated by Radio Imaging Audio Generator - Llama Edition)"
    descriptive_text = st.session_state.get("refined_prompt") or ""
    if descriptive_text.endswith(credit_suffix):
        descriptive_text = descriptive_text[: -len(credit_suffix)]

    if not descriptive_text.strip():
        st.error("Please generate or have a refined description first.")
    else:
        with st.spinner("Generating your audio... This can take a moment."):
            try:
                musicgen_model, processor = load_musicgen_model()
                inputs = processor(
                    text=[descriptive_text],
                    padding=True,
                    return_tensors="pt",
                )
                # max_new_tokens caps the number of generated audio tokens,
                # and therefore the clip length.
                audio_values = musicgen_model.generate(**inputs, max_new_tokens=512)
                sampling_rate = musicgen_model.config.audio_encoder.sampling_rate

                # Encode the WAV in memory instead of a fixed file in the
                # working directory. Concurrent sessions would otherwise
                # overwrite each other's output. .cpu() keeps .numpy() valid
                # even if the model is ever moved to a GPU.
                wav_buffer = io.BytesIO()
                scipy.io.wavfile.write(
                    wav_buffer,
                    rate=sampling_rate,
                    data=audio_values[0, 0].cpu().numpy(),
                )
                wav_bytes = wav_buffer.getvalue()

                st.success("Audio successfully generated!")
                st.audio(wav_bytes, format="audio/wav")
                st.download_button(
                    "📥 Download Audio",
                    wav_bytes,
                    file_name="radio_imaging_output.wav",
                    mime="audio/wav",
                )
            except Exception as e:
                # UI boundary: surface any load/generation failure to the user.
                st.error(f"Error while generating audio: {e}")
# ---------------------------------------------------------------------
# Footer Section: credits line, then optional hiding of Streamlit chrome
# ---------------------------------------------------------------------
st.markdown("---")
st.markdown(
    "".join(
        (
            "<div class='footer-note'>",
            "✅ Built with Llama 2 & MusicGen · ",
            "Created for radio imaging producers · ",
            "Feedback welcome at <a href='https://bilsimaging.com' target='_blank'>Bilsimaging</a>!",
            "</div>",
        )
    ),
    unsafe_allow_html=True,
)
# Hides Streamlit's default hamburger menu and "Made with Streamlit" footer.
st.markdown(
    "<style>#MainMenu {visibility: hidden;} footer {visibility: hidden;}</style>",
    unsafe_allow_html=True,
)