# AIPromoStudio / app.py
# Bils's picture
# Update app.py
# c243adb verified
# raw
# history blame
# 9.98 kB
import streamlit as st
import requests
import torch
import scipy.io.wavfile
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
pipeline,
AutoProcessor,
MusicgenForConditionalGeneration
)
from io import BytesIO
from streamlit_lottie import st_lottie # pip install streamlit-lottie
# ---------------------------------------------------------------------
# 1) PAGE CONFIG
# ---------------------------------------------------------------------
# Browser-tab title, tab icon and layout mode for the whole app.
_PAGE_SETTINGS = {
    "page_title": "Radio Imaging AI MVP",
    "page_icon": "🎧",
    "layout": "wide",
}
st.set_page_config(**_PAGE_SETTINGS)
# ---------------------------------------------------------------------
# 2) CUSTOM CSS / SPOTIFY-LIKE UI
# ---------------------------------------------------------------------
# Single stylesheet for the app: dark palette, green accents, rounded
# buttons/inputs, plus the helper classes used further down the script
# (`.lottie-container`, `.footer-note`).
# NOTE(review): the `.sidebar .sidebar-content` selector may not match the
# markup emitted by current Streamlit releases -- verify in the browser.
CUSTOM_CSS = """
<style>
/* Body styling for a dark, music-app vibe */
body {
background-color: #121212;
color: #FFFFFF;
font-family: "Helvetica Neue", sans-serif;
}
/* Main container width */
.block-container {
max-width: 1100px;
padding: 1rem 1.5rem;
}
/* Headings with a neon-ish green accent */
h1, h2, h3 {
color: #1DB954;
margin-bottom: 0.5rem;
}
/* Buttons: rounded, bright Spotify-like green on hover */
.stButton>button {
background-color: #1DB954 !important;
color: #FFFFFF !important;
border-radius: 24px;
border: none;
font-size: 16px !important;
padding: 0.6rem 1.2rem !important;
transition: background-color 0.3s ease;
}
.stButton>button:hover {
background-color: #1ed760 !important;
}
/* Sidebar: black background, white text */
.sidebar .sidebar-content {
background-color: #000000;
color: #FFFFFF;
}
/* Text inputs and text areas */
textarea, input, select {
border-radius: 8px !important;
background-color: #282828 !important;
color: #FFFFFF !important;
border: 1px solid #3e3e3e;
}
/* Audio player styling */
audio {
width: 100%;
margin-top: 1rem;
}
/* Lottie container styling */
.lottie-container {
display: flex;
justify-content: center;
margin-bottom: 20px;
}
/* Footer styling */
.footer-note {
text-align: center;
font-size: 14px;
opacity: 0.7;
margin-top: 2rem;
}
/* Hide Streamlit's default branding if desired */
#MainMenu, footer {visibility: hidden;}
</style>
"""
# Inject the stylesheet into the page; unsafe_allow_html=True is passed
# because the payload is raw HTML (a <style> element), not Markdown text.
st.markdown(CUSTOM_CSS, unsafe_allow_html=True)
# ---------------------------------------------------------------------
# 3) HELPER: LOAD LOTTIE ANIMATION
# ---------------------------------------------------------------------
@st.cache_data
def load_lottie_url(url: str):
    """
    Fetch Lottie JSON for animations.

    Args:
        url: Address of the Lottie animation JSON document.

    Returns:
        The decoded JSON payload, or None when the request fails, the server
        answers with a non-200 status, or the body is not valid JSON.
    """
    try:
        # Bounded wait: this runs at module level on script execution, so a
        # stalled host must not be able to hang the whole page.
        response = requests.get(url, timeout=10)
    except requests.RequestException:
        # The animation is decorative; a network failure should not crash
        # the app. The caller falls back to a text placeholder on None.
        return None
    if response.status_code != 200:
        return None
    try:
        return response.json()
    except ValueError:
        # requests signals an undecodable body with a ValueError subclass.
        return None


# Example Lottie animation (radio waves / music eq, etc.)
LOTTIE_URL = "https://assets3.lottiefiles.com/temp/lf20_Q6h5zV.json"
lottie_animation = load_lottie_url(LOTTIE_URL)
# ---------------------------------------------------------------------
# 4) SIDEBAR: "LIBRARY" NAVIGATION (MIMICS SPOTIFY)
# ---------------------------------------------------------------------
# Static, non-interactive entries that only imitate a music-app library.
_SIDEBAR_LINES = (
    "**My Stations**",
    "- Favorites",
    "- Recently Generated",
    "- Top Hits",
    "---",
    "**Settings**",
)
with st.sidebar:
    st.header("🎚 Radio Library")
    for _sidebar_line in _SIDEBAR_LINES:
        st.write(_sidebar_line)
    # Trailing line break used purely as vertical spacing.
    st.markdown("<br>", unsafe_allow_html=True)
# ---------------------------------------------------------------------
# 5) HEADER SECTION WITH LOTS OF FLARE
# ---------------------------------------------------------------------
# Intro copy rendered under the title in the left-hand column.
_INTRO_MARKDOWN = """
Create **radio imaging promos** and **jingles** with a minimal but creative MVP.
This app:
- Uses a (hypothetical) [Llama 3] model for **script generation**.
- Uses Meta's [MusicGen](https://github.com/facebookresearch/audiocraft) for **audio**.
- Features a Spotify-like UI & Lottie animations for a modern user experience.
"""
col1, col2 = st.columns([3, 2], gap="large")
with col1:
    st.title("AI Radio Imaging MVP")
    st.subheader("Llama-Driven Promo Scripts, MusicGen Audio")
    st.markdown(_INTRO_MARKDOWN)
with col2:
    # The animation payload is falsy when it could not be fetched; show a
    # short text placeholder in that case.
    if not lottie_animation:
        st.write("*No animation loaded.*")
    else:
        with st.container():
            st_lottie(
                lottie_animation,
                height=180,
                loop=True,
                key="radio_lottie",
            )
st.markdown("---")
# ---------------------------------------------------------------------
# 6) PROMPT INPUT & MODEL SELECTION
# ---------------------------------------------------------------------
# The subheader label previously contained a mis-decoded emoji; it is
# restored to the intended studio-microphone character.
st.subheader("🎙 Step 1: Briefly Describe Your Promo Idea")

# Free-text concept; read by the script-generation button handler below.
prompt = st.text_area(
    "E.g. 'A 15-second upbeat jingle with a catchy hook for a Top 40 morning show'",
    height=120,
)

col_model, col_device = st.columns(2)
with col_model:
    # Hugging Face model ID handed to the Llama loader.
    llama_model_id = st.text_input(
        "Llama Model (Hugging Face ID)",
        value="meta-llama/Llama-3.3-70B-Instruct",  # Replace with a real model
        help="If non-existent, you'll see errors. Try Llama 2 (e.g. meta-llama/Llama-2-7b-chat-hf).",
    )
with col_device:
    # The chosen value is forwarded as the `device_map` when loading the model.
    device_option = st.selectbox(
        "Choose Device",
        ["auto", "cpu"],
        help="For GPU usage, pick 'auto'. CPU can be slow for big models.",
    )
# ---------------------------------------------------------------------
# 7) HELPER FUNCTIONS
# ---------------------------------------------------------------------
# NOTE: Streamlit executes this script top to bottom on every rerun, so these
# helpers must be defined BEFORE the button handlers that call them;
# otherwise a button click fails with NameError.
@st.cache_resource
def load_llama_pipeline(model_id: str, device: str):
    """
    Load the Llama model & pipeline.

    Args:
        model_id: Hugging Face model identifier to load.
        device: Value forwarded as `device_map` ("auto" or "cpu" in this UI).

    Returns:
        A transformers "text-generation" pipeline wrapping the loaded model
        and tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        # Half precision when device placement is automatic, full precision
        # otherwise.
        torch_dtype=torch.float16 if device == "auto" else torch.float32,
        device_map=device,
    )
    text_gen_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device_map=device,
    )
    return text_gen_pipeline


def generate_radio_script(user_input: str, pipeline_llama) -> str:
    """
    Use Llama to refine the user's input into a brief but creative radio imaging script.

    Args:
        user_input: The promo concept typed by the user.
        pipeline_llama: Callable text-generation pipeline; called with the
            prompt and expected to return a list whose first item has a
            "generated_text" entry.

    Returns:
        The generated script with the prompt removed, followed by an
        attribution line.
    """
    system_prompt = (
        "You are a top-tier radio imaging producer. "
        "Take the user's concept and craft a short, high-impact promo script. "
        "Include style, tone, and potential CTA if relevant."
    )
    marker = "Refined script:"
    full_prompt = f"{system_prompt}\nUser concept: {user_input}\n{marker}"
    output = pipeline_llama(
        full_prompt,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.9,
    )[0]["generated_text"]
    # Isolate the final script portion. Prefer an exact prefix strip so that
    # a user concept which itself contains the marker text cannot truncate
    # the result; fall back to splitting on the marker otherwise.
    if output.startswith(full_prompt):
        output = output[len(full_prompt):].strip()
    elif marker in output:
        output = output.split(marker, 1)[-1].strip()
    output += "\n\n(Generated by Llama in Radio Imaging MVP)"
    return output


@st.cache_resource
def load_musicgen_model():
    """
    Load MusicGen (small version).

    Returns:
        A `(model, processor)` tuple for the "facebook/musicgen-small"
        checkpoint.
    """
    mg_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
    mg_processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
    return mg_model, mg_processor


# ---------------------------------------------------------------------
# 8) BUTTON: GENERATE RADIO SCRIPT WITH LLAMA
# ---------------------------------------------------------------------
if st.button("📝 Generate Promo Script"):
    if not prompt.strip():
        st.error("Please enter a radio imaging concept first.")
    else:
        with st.spinner("Generating script..."):
            # Broad catch on purpose: this is the UI boundary, and model
            # download/load/generation failures are shown to the user
            # instead of crashing the page.
            try:
                # Load Llama pipeline
                pipeline_llama = load_llama_pipeline(llama_model_id, device_option)
                # Generate refined script
                refined_text = generate_radio_script(prompt, pipeline_llama)
                # Persist across reruns so the audio step can read it.
                st.session_state["refined_script"] = refined_text
                st.success("Promo script generated!")
                st.write(refined_text)
            except Exception as e:
                st.error(f"Error during Llama generation: {e}")
st.markdown("---")

# ---------------------------------------------------------------------
# 9) AUDIO GENERATION: MUSICGEN
# ---------------------------------------------------------------------
st.subheader("🎶 Step 2: Generate Your Radio Audio")
audio_tokens = st.slider("MusicGen Max Tokens (Track Length)", 128, 1024, 512, 64)
if st.button("🎧 Create Audio with MusicGen"):
    # Check if we have a refined script
    if "refined_script" not in st.session_state:
        st.error("Please generate a promo script first.")
    else:
        with st.spinner("Generating audio..."):
            try:
                # Load MusicGen
                mg_model, mg_processor = load_musicgen_model()
                descriptive_text = st.session_state["refined_script"]
                # Prepare model input
                inputs = mg_processor(
                    text=[descriptive_text],
                    return_tensors="pt",
                    padding=True,
                )
                # Generate audio
                audio_values = mg_model.generate(**inputs, max_new_tokens=audio_tokens)
                sr = mg_model.config.audio_encoder.sampling_rate
                # Save audio to WAV (element [0, 0] of the generated tensor).
                out_filename = "radio_imaging_output.wav"
                scipy.io.wavfile.write(out_filename, rate=sr, data=audio_values[0, 0].numpy())
                st.success("Audio created! Press play to listen:")
                st.audio(out_filename)
            except Exception as e:
                st.error(f"Error generating audio: {e}")
# ---------------------------------------------------------------------
# 10) FOOTER
# ---------------------------------------------------------------------
# Footer markup; styled by the `.footer-note` rule in the custom stylesheet.
_FOOTER_HTML = """
<div class="footer-note">
&copy; 2025 Radio Imaging MVP &ndash; Built with Llama & MusicGen. <br>
Inspired by Spotify's UI for a sleek, modern experience.
</div>
"""
st.markdown("---")
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)