# AIPromoStudio / app.py
# Hugging Face Space source by Bils (commit 613bd9e, "Update app.py"; 7.98 kB).
import os
import requests
import torch
import scipy.io.wavfile as wav
import streamlit as st
from io import BytesIO
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
pipeline,
AutoProcessor,
MusicgenForConditionalGeneration
)
from streamlit_lottie import st_lottie
# ---------------------------------------------------------------------
# 1) PAGE CONFIGURATION
# ---------------------------------------------------------------------
# Browser-tab title/icon and a wide page layout. This is issued before any
# other element of the app is rendered.
st.set_page_config(
    layout="wide",
    page_icon="🎧",
    page_title="AI Radio Imaging with Llama 3",
)
# ---------------------------------------------------------------------
# 2) CUSTOM CSS / UI DESIGN
# ---------------------------------------------------------------------
# Dark theme with green accents, rounded buttons/inputs, full-width audio
# player, and the styling for the ".footer-note" element rendered at the end
# of the page. The default Streamlit main menu and footer are hidden.
CUSTOM_CSS = """
<style>
body {
background-color: #121212;
color: #FFFFFF;
font-family: "Helvetica Neue", sans-serif;
}
.block-container {
max-width: 1100px;
padding: 1rem 1.5rem;
}
h1, h2, h3 {
color: #1DB954;
}
.stButton>button {
background-color: #1DB954 !important;
color: #FFFFFF !important;
border-radius: 24px;
padding: 0.6rem 1.2rem;
}
.stButton>button:hover {
background-color: #1ed760 !important;
}
textarea, input, select {
border-radius: 8px !important;
background-color: #282828 !important;
color: #FFFFFF !important;
}
audio {
width: 100%;
margin-top: 1rem;
}
.footer-note {
text-align: center;
font-size: 14px;
opacity: 0.7;
margin-top: 2rem;
}
#MainMenu, footer {visibility: hidden;}
</style>
"""

# Raw HTML is required so the <style> tag is injected rather than escaped.
st.markdown(body=CUSTOM_CSS, unsafe_allow_html=True)
# ---------------------------------------------------------------------
# 3) LOAD LOTTIE ANIMATION
# ---------------------------------------------------------------------
@st.cache_data
def load_lottie_url(url: str, timeout: float = 10.0):
    """Fetch a Lottie animation JSON document from *url*.

    Args:
        url: Address of the Lottie JSON file.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The decoded JSON document, or ``None`` when the request fails, the
        server answers with a non-200 status, or the body is not valid JSON.
        The animation is purely decorative, so any failure degrades to
        "no animation" instead of crashing the app while it starts up.
    """
    try:
        # Without a timeout a stalled server would hang the whole page load.
        response = requests.get(url, timeout=timeout)
    except requests.RequestException:
        return None
    if response.status_code != 200:
        return None
    try:
        return response.json()
    except ValueError:
        # requests signals an undecodable body with a ValueError subclass.
        return None


LOTTIE_URL = "https://assets3.lottiefiles.com/temp/lf20_Q6h5zV.json"
lottie_animation = load_lottie_url(LOTTIE_URL)
# ---------------------------------------------------------------------
# 4) LOAD LLAMA 3 (GATED MODEL)
# ---------------------------------------------------------------------
@st.cache_resource
def load_llama_pipeline(model_id: str, device: str, token: str):
    """Build a text-generation pipeline, cached for the process lifetime.

    Args:
        model_id: Hugging Face model repository id (may be a gated repo).
        device: ``"auto"`` to let Accelerate place the weights, or ``"cpu"``;
            forwarded to ``from_pretrained`` as ``device_map``.
        token: Hugging Face access token used to download the weights.

    Returns:
        A ``transformers`` text-generation pipeline wrapping the model.

    Raises:
        RuntimeError: if the tokenizer, model, or pipeline cannot be created.
            The original exception is chained as ``__cause__``. Reporting it
            in the UI is left to the caller.
    """
    # Half precision is only worthwhile on CUDA. Choosing it for "auto" on a
    # CPU-only host would load fp16 weights on the CPU.
    use_fp16 = device == "auto" and torch.cuda.is_available()
    dtype = torch.float16 if use_fp16 else torch.float32
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            token=token,
            torch_dtype=dtype,
            device_map=device,
            low_cpu_mem_usage=True,
        )
        # Placement is already decided by ``device_map`` above; passing it to
        # the pipeline again is redundant for an already-loaded model object.
        return pipeline("text-generation", model=model, tokenizer=tokenizer)
    except Exception as exc:
        raise RuntimeError(f"Error loading Llama model: {exc}") from exc
# ---------------------------------------------------------------------
# 5) GENERATE RADIO SCRIPT
# ---------------------------------------------------------------------
def generate_radio_script(user_input: str, pipeline_llama) -> str:
    """Turn a short promo concept into a radio-imaging script.

    Args:
        user_input: Free-text description of the promo idea.
        pipeline_llama: A ``transformers`` text-generation pipeline, or any
            callable accepting the same arguments and returning
            ``[{"generated_text": str}, ...]``.

    Returns:
        The generated script (prompt removed, surrounding whitespace
        stripped), followed by a "(Generated by Llama 3 - Radio Imaging)"
        signature line.
    """
    system_prompt = (
        "You are a top-tier radio imaging producer using Llama 3. "
        "Take the user's concept and craft a short, creative promo script."
    )
    combined_prompt = f"{system_prompt}\nUser concept: {user_input}\nRefined script:"
    result = pipeline_llama(
        combined_prompt,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.9,
        # Ask for the continuation only. Splitting the echoed prompt on a
        # marker breaks when the user's own text contains that marker.
        return_full_text=False,
    )
    output_text = result[0]["generated_text"]
    # Fallback for callables that ignore return_full_text and echo the prompt.
    if output_text.startswith(combined_prompt):
        output_text = output_text[len(combined_prompt):]
    output_text = output_text.strip()
    output_text += "\n\n(Generated by Llama 3 - Radio Imaging)"
    return output_text
# ---------------------------------------------------------------------
# 6) LOAD MUSICGEN
# ---------------------------------------------------------------------
@st.cache_resource
def load_musicgen_model(model_id: str = "facebook/musicgen-small"):
    """Load a MusicGen model and its matching processor (cached per process).

    Args:
        model_id: Hugging Face repository id of the MusicGen checkpoint.
            Both the model and the processor are loaded from this same id,
            so they can never drift apart.

    Returns:
        ``(model, processor)`` tuple.
    """
    mg_model = MusicgenForConditionalGeneration.from_pretrained(model_id)
    mg_processor = AutoProcessor.from_pretrained(model_id)
    return mg_model, mg_processor
# ---------------------------------------------------------------------
# 7) HEADER
# ---------------------------------------------------------------------
st.title("🎧 AI Radio Imaging with Llama 3")
st.subheader("Create engaging radio promos with Llama 3 + MusicGen")
# There is no token input widget: the token is read from the HF_TOKEN
# environment variable, so the instructions must point there.
st.markdown(
    "Create **radio imaging promos** and **jingles** easily. Ensure your Hugging Face\n"
    "account has access to the Llama 3 model selected below (default\n"
    "**meta-llama/Meta-Llama-3-70B**) and that your access token is set in the\n"
    "`HF_TOKEN` environment variable."
)
if lottie_animation:
    st_lottie(lottie_animation, height=180, loop=True, key="radio_lottie")
st.markdown("---")
# ---------------------------------------------------------------------
# 8) USER INPUT
# ---------------------------------------------------------------------
st.subheader("🎀 Step 1: Describe Your Promo Idea")
prompt = st.text_area(
    "Example: 'A 15-second hype jingle for a morning talk show, fun and energetic.'",
    height=120,
)
col_model, col_device = st.columns(2)
with col_model:
    llama_model_id = st.text_input(
        "Llama 3 Model ID",
        value="meta-llama/Meta-Llama-3-70B",
        help="Enter the exact model ID from Hugging Face.",
    )
with col_device:
    device_option = st.selectbox(
        "Device",
        ["auto", "cpu"],
        help="Choose GPU (auto) or CPU.",
    )

# The loader needs a Hugging Face token. Halt the rest of the script when it
# is missing rather than failing later inside the model download.
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    st.error("No HF_TOKEN found. Please set it in your environment.")
    st.stop()

if st.button("✍ Generate Promo Script"):
    if not prompt.strip():
        st.error("Please provide a concept first.")
    else:
        with st.spinner("Generating script..."):
            try:
                llama_pipeline = load_llama_pipeline(llama_model_id, device_option, hf_token)
                final_script = generate_radio_script(prompt, llama_pipeline)
            except Exception as e:
                st.error(f"Llama generation error: {e}")
            else:
                # Streamlit re-runs this script on every interaction, so a plain
                # local is lost before Step 2 runs. Step 2 reads the script from
                # st.session_state["final_script"], so it must be stored there.
                st.session_state["final_script"] = final_script
                st.success("Promo script generated!")

# Show the latest script on every rerun, not only immediately after the click.
if "final_script" in st.session_state:
    st.text_area("Generated Script", value=st.session_state["final_script"], height=200)
st.markdown("---")
# ---------------------------------------------------------------------
# 9) GENERATE AUDIO WITH MUSICGEN
# ---------------------------------------------------------------------
st.subheader("🎡 Step 2: Generate Audio")
audio_length = st.slider("Track Length (tokens)", 128, 1024, 512, 64)
if st.button("🎧 Create Audio"):
    if "final_script" not in st.session_state:
        st.error("Please generate a script first.")
    else:
        with st.spinner("Generating audio..."):
            try:
                mg_model, mg_processor = load_musicgen_model()
                inputs = mg_processor(
                    text=[st.session_state["final_script"]],
                    padding=True,
                    return_tensors="pt",
                )
                audio_values = mg_model.generate(**inputs, max_new_tokens=audio_length)
                sr = mg_model.config.audio_encoder.sampling_rate
                # Presumably (batch, channels, samples) as in the transformers
                # MusicGen examples: take batch item 0, channel 0.
                audio_data = audio_values[0, 0].cpu().numpy()
                # Peak-normalise to the int16 range. A silent clip has a zero
                # peak; dividing by it would produce NaNs, so leave it as zeros.
                peak = float(abs(audio_data).max()) if audio_data.size else 0.0
                if peak > 0:
                    audio_data = audio_data / peak * 32767
                normalized_audio = audio_data.astype("int16")
                # Encode in memory instead of a fixed file on disk, so concurrent
                # sessions cannot overwrite each other's jingle.
                wav_buffer = BytesIO()
                wav.write(wav_buffer, rate=sr, data=normalized_audio)
                wav_buffer.seek(0)
                st.success("Audio generated! Play it below:")
                st.audio(wav_buffer, format="audio/wav")
            except Exception as e:
                st.error(f"MusicGen error: {e}")
# ---------------------------------------------------------------------
# 10) FOOTER
# ---------------------------------------------------------------------
st.markdown("---")
# Raw HTML so the "footer-note" class from the custom CSS applies.
# The copyright sign was mis-encoded as "Β©" and is restored to "©".
st.markdown(
    """
<div class="footer-note">
© 2025 AI Radio Imaging – Built with Hugging Face & Streamlit
</div>
""",
    unsafe_allow_html=True,
)