AIPromoStudio / app.py
Bils's picture
Update app.py
a15d204 verified
raw
history blame
9.54 kB
import os
import requests
import torch
import scipy.io.wavfile
import streamlit as st
from io import BytesIO
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
pipeline,
AutoProcessor,
MusicgenForConditionalGeneration
)
from streamlit_lottie import st_lottie
# ---------------------------------------------------------------------
# 1) PAGE CONFIG
# ---------------------------------------------------------------------
# Browser-tab title and icon, plus the full-width page layout.
st.set_page_config(
    layout="wide",
    page_icon="🎧",
    page_title="Radio Imaging AI with Llama 3",
)
# ---------------------------------------------------------------------
# 2) CUSTOM CSS / SPOTIFY-LIKE UI
# ---------------------------------------------------------------------
# Stylesheet injected into the page: dark background, green headings and
# buttons, dark form fields, full-width audio player, and helper classes
# (.lottie-container, .footer-note). It also hides #MainMenu and <footer>.
# NOTE(review): selectors such as ".sidebar .sidebar-content" depend on
# Streamlit's generated markup -- confirm they still match the installed version.
CUSTOM_CSS = """
<style>
/* Dark background with Spotify-like vibe */
body {
background-color: #121212;
color: #FFFFFF;
font-family: "Helvetica Neue", sans-serif;
}
.block-container {
max-width: 1100px;
padding: 1rem 1.5rem;
}
h1, h2, h3 {
color: #1DB954;
margin-bottom: 0.5rem;
}
/* Rounded, bright green button on hover */
.stButton>button {
background-color: #1DB954 !important;
color: #FFFFFF !important;
border-radius: 24px;
border: none;
font-size: 16px !important;
padding: 0.6rem 1.2rem !important;
transition: background-color 0.3s ease;
}
.stButton>button:hover {
background-color: #1ed760 !important;
}
/* Sidebar: black background */
.sidebar .sidebar-content {
background-color: #000000;
color: #FFFFFF;
}
textarea, input, select {
border-radius: 8px !important;
background-color: #282828 !important;
color: #FFFFFF !important;
border: 1px solid #3e3e3e;
}
/* Audio styling */
audio {
width: 100%;
margin-top: 1rem;
}
/* Lottie container */
.lottie-container {
display: flex;
justify-content: center;
margin-bottom: 20px;
}
/* Footer */
.footer-note {
text-align: center;
font-size: 14px;
opacity: 0.7;
margin-top: 2rem;
}
/* Hide Streamlit branding if you wish */
#MainMenu, footer {visibility: hidden;}
</style>
"""
# The string is raw HTML (a <style> element), hence unsafe_allow_html=True.
st.markdown(CUSTOM_CSS, unsafe_allow_html=True)
# ---------------------------------------------------------------------
# 3) LOAD LOTTIE ANIMATION
# ---------------------------------------------------------------------
@st.cache_data
def load_lottie_url(url: str):
    """Download a Lottie animation and return its parsed JSON.

    Args:
        url: Address of the Lottie JSON document.

    Returns:
        The decoded JSON payload, or ``None`` when the request fails, the
        server answers with a status other than 200, or the body is not
        valid JSON. Callers are expected to treat ``None`` as "no animation".
    """
    try:
        # A timeout keeps a stalled host from hanging the app at start-up.
        response = requests.get(url, timeout=10)
    except requests.RequestException:
        # DNS failure, refused connection, timeout, ... -> no animation.
        return None

    if response.status_code != 200:
        return None

    try:
        return response.json()
    except ValueError:
        # The body was not JSON (for example an HTML error page).
        return None
# Animation shown in the page header. load_lottie_url returns None on a
# non-200 response; the header section then shows a text placeholder instead.
LOTTIE_URL = "https://assets3.lottiefiles.com/temp/lf20_Q6h5zV.json"
lottie_animation = load_lottie_url(LOTTIE_URL)
# ---------------------------------------------------------------------
# 4) LOAD LLAMA 3 (GATED MODEL) - WITH HF ACCESS TOKEN
# ---------------------------------------------------------------------
@st.cache_resource
def load_llama_pipeline(model_id: str, device: str, token: str):
    """Build a text-generation pipeline for a gated Llama 3 checkpoint.

    Args:
        model_id: Hugging Face repository ID of the model to load.
        device: Value forwarded as ``device_map`` (the UI offers ``"auto"``
            and ``"cpu"``).
        token: Hugging Face access token used to download the gated weights.

    Returns:
        A ``transformers`` "text-generation" pipeline wrapping the loaded
        model and tokenizer.
    """
    # Half precision only pays off on a GPU; when "auto" falls back to the
    # CPU, float32 avoids slow or unsupported fp16 CPU kernels.
    use_half = device == "auto" and torch.cuda.is_available()
    dtype = torch.float16 if use_half else torch.float32

    # `token` replaces the deprecated `use_auth_token` argument.
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=token,
        torch_dtype=dtype,
        device_map=device,
    )
    text_gen_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device_map=device,
    )
    return text_gen_pipeline
# ---------------------------------------------------------------------
# 5) REFINE SCRIPT (LLAMA)
# ---------------------------------------------------------------------
def generate_radio_script(user_input: str, pipeline_llama) -> str:
    """Turn a rough promo concept into a short radio script via the LLM.

    The prompt has three lines: a producer persona, the user's concept, and
    a "Refined script:" cue. Up to 200 new tokens are sampled at temperature
    0.9. If the generated text contains the cue, only the (stripped) text
    after its first occurrence is kept. A fixed attribution line is appended.

    Args:
        user_input: Free-form description of the promo.
        pipeline_llama: Callable invoked as ``pipeline_llama(prompt, **kwargs)``
            whose first result item has a ``"generated_text"`` entry.

    Returns:
        The script text followed by the attribution line.
    """
    marker = "Refined script:"
    persona = (
        "You are a top-tier radio imaging producer using Llama 3. "
        "Take the user's concept and craft a short, creative promo script."
    )
    full_prompt = "\n".join((persona, f"User concept: {user_input}", marker))

    generations = pipeline_llama(
        full_prompt,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.9,
    )
    script = generations[0]["generated_text"]

    # Keep only what follows the first marker; leave the text as-is if absent.
    _, found, tail = script.partition(marker)
    if found:
        script = tail.strip()

    return script + "\n\n(Generated by Llama 3 - Radio Imaging)"
# ---------------------------------------------------------------------
# 6) LOAD MUSICGEN
# ---------------------------------------------------------------------
@st.cache_resource
def load_musicgen_model(model_name: str = "facebook/musicgen-small"):
    """Load a MusicGen checkpoint together with its processor.

    Args:
        model_name: Hugging Face repository ID of the MusicGen checkpoint.
            Defaults to the small model, so existing no-argument calls
            behave as before.

    Returns:
        A ``(model, processor)`` tuple, both loaded from ``model_name``.
    """
    mg_model = MusicgenForConditionalGeneration.from_pretrained(model_name)
    mg_processor = AutoProcessor.from_pretrained(model_name)
    return mg_model, mg_processor
# ---------------------------------------------------------------------
# 7) SIDEBAR
# ---------------------------------------------------------------------
# Static sidebar: a placeholder "library" list followed by a Settings heading.
with st.sidebar:
    st.header("🎚 Radio Library")
    for sidebar_line in (
        "**My Stations**",
        "- Favorites",
        "- Recently Generated",
        "- Top Hits",
        "---",
        "**Settings**",
    ):
        st.write(sidebar_line)
    st.markdown("<br>", unsafe_allow_html=True)
# ---------------------------------------------------------------------
# 8) HEADER
# ---------------------------------------------------------------------
col1, col2 = st.columns([3, 2], gap="large")

# Left column: title, tagline and the access requirements.
col1.title("AI Radio Imaging with Llama 3")
col1.subheader("Gated Model + MusicGen Audio")
col1.markdown(
    """
Create **radio imaging promos** and **jingles** with Llama 3 + MusicGen.
**Note**:
- You must have access to `"meta-llama/Llama-3-70B-Instruct"` on Hugging Face.
- You must provide your HF token in the environment (e.g., HF_TOKEN).
"""
)

# Right column: the animation if it loaded, otherwise a text placeholder.
with col2:
    if not lottie_animation:
        st.write("*No animation loaded.*")
    else:
        with st.container():
            st_lottie(lottie_animation, height=180, loop=True, key="radio_lottie")

st.markdown("---")
# ---------------------------------------------------------------------
# 9) SCRIPT GENERATION
# ---------------------------------------------------------------------
# Step 1 UI: collect the concept, model ID and device, then generate the script.
st.subheader("🎙 Step 1: Describe Your Promo Idea")
prompt = st.text_area(
    "Example: 'A 15-second hype jingle for a morning talk show, fun and energetic.'",
    height=120
)
col_model, col_device = st.columns(2)
with col_model:
    # NOTE(review): confirm this default ID exists on the Hub -- the Llama 3
    # 70B instruct checkpoint may be listed as
    # "meta-llama/Meta-Llama-3-70B-Instruct".
    llama_model_id = st.text_input(
        "Llama 3 Model ID",
        value="meta-llama/Llama-3-70B-Instruct",
        help="Use the exact name you see on the Hugging Face model page."
    ).strip()  # stray whitespace would otherwise become part of the repo ID
with col_device:
    device_option = st.selectbox(
        "Device (GPU vs CPU)",
        ["auto", "cpu"],
        help="If you have GPU, 'auto' tries to use it; CPU might be slow."
    )

# The gated model cannot be downloaded without a token, so halt the script
# run here when none is configured.
my_token = os.getenv("HF_TOKEN")
if not my_token:
    st.error("No HF_TOKEN found. Please set it in your HF Space secrets or environment variables.")
    st.stop()

if st.button("📝 Generate Promo Script"):
    if not prompt.strip():
        st.error("Please type some concept first.")
    elif not llama_model_id:
        st.error("Please enter a Llama 3 model ID.")
    else:
        with st.spinner("Generating script with Llama 3..."):
            # Broad catch on purpose: this is the UI boundary, and any
            # download/generation failure should surface as a message.
            try:
                llm_pipeline = load_llama_pipeline(llama_model_id, device_option, my_token)
                final_script = generate_radio_script(prompt, llm_pipeline)
                # Stored so the audio step can read it on a later rerun.
                st.session_state["final_script"] = final_script
                st.success("Promo script generated!")
                st.write(final_script)
            except Exception as e:
                st.error(f"Llama generation error: {e}")
st.markdown("---")
# ---------------------------------------------------------------------
# 10) AUDIO GENERATION: MUSICGEN
# ---------------------------------------------------------------------
# Step 2 UI: turn the stored script into audio with MusicGen.
st.subheader("🎶 Step 2: Generate Audio")
audio_length = st.slider("MusicGen Max Tokens (approx track length)", 128, 1024, 512, 64)
if st.button("🎧 Create Audio with MusicGen"):
    if "final_script" not in st.session_state:
        st.error("No script found. Please generate a script first.")
    else:
        with st.spinner("Creating audio..."):
            # Broad catch on purpose: this is the UI boundary, and any
            # model or encoding failure should surface as a message.
            try:
                mg_model, mg_processor = load_musicgen_model()
                text_for_audio = st.session_state["final_script"]
                inputs = mg_processor(
                    text=[text_for_audio],
                    padding=True,
                    return_tensors="pt"
                )
                audio_values = mg_model.generate(**inputs, max_new_tokens=audio_length)
                sr = mg_model.config.audio_encoder.sampling_rate
                # Encode the WAV in memory rather than to a fixed filename in
                # the working directory, which concurrent sessions would
                # overwrite for each other.
                wav_buffer = BytesIO()
                scipy.io.wavfile.write(
                    wav_buffer,
                    rate=sr,
                    data=audio_values[0, 0].cpu().numpy(),
                )
                st.success("Audio generated! Press play below:")
                st.audio(wav_buffer.getvalue(), format="audio/wav")
            except Exception as e:
                st.error(f"MusicGen error: {e}")
# ---------------------------------------------------------------------
# 11) FOOTER
# ---------------------------------------------------------------------
# Footer: divider plus a centred note styled by the .footer-note CSS class.
st.markdown("---")
st.markdown(
    """
<div class="footer-note">
© 2025 Radio Imaging with Llama 3 – Built using Hugging Face & Streamlit. <br>
Log in or provide <code>HF_TOKEN</code> and ensure access to <strong>meta-llama/Llama-3-70B-Instruct</strong>.
</div>
""",
    unsafe_allow_html=True
)