Spaces:
Running
Running
import os | |
import requests | |
import torch | |
import scipy.io.wavfile | |
import streamlit as st | |
from io import BytesIO | |
from transformers import ( | |
AutoTokenizer, | |
AutoModelForCausalLM, | |
pipeline, | |
AutoProcessor, | |
MusicgenForConditionalGeneration | |
) | |
from streamlit_lottie import st_lottie | |
# ---------------------------------------------------------------------
# 1) PAGE CONFIG
# ---------------------------------------------------------------------
# Kept as the first Streamlit call in the script so the title/icon/layout
# apply to the whole page.
st.set_page_config(
    page_title="Radio Imaging AI with Llama 3",
    page_icon="🎧",
    layout="wide",
)
# ---------------------------------------------------------------------
# 2) CUSTOM CSS / SPOTIFY-LIKE UI
# ---------------------------------------------------------------------
# Each rule lists the legacy selector alongside the selector used by current
# Streamlit markup, so the styling applies on both old and new versions.
CUSTOM_CSS = """
<style>
/* Dark background with Spotify-like vibe (.stApp paints over <body>) */
body, .stApp {
    background-color: #121212;
    color: #FFFFFF;
    font-family: "Helvetica Neue", sans-serif;
}
.block-container {
    max-width: 1100px;
    padding: 1rem 1.5rem;
}
h1, h2, h3 {
    color: #1DB954;
    margin-bottom: 0.5rem;
}
/* Rounded green button that brightens on hover */
.stButton>button {
    background-color: #1DB954 !important;
    color: #FFFFFF !important;
    border-radius: 24px;
    border: none;
    font-size: 16px !important;
    padding: 0.6rem 1.2rem !important;
    transition: background-color 0.3s ease;
}
.stButton>button:hover {
    background-color: #1ed760 !important;
}
/* Sidebar: black background (legacy class + current data-testid) */
.sidebar .sidebar-content,
section[data-testid="stSidebar"] {
    background-color: #000000;
    color: #FFFFFF;
}
textarea, input, select {
    border-radius: 8px !important;
    background-color: #282828 !important;
    color: #FFFFFF !important;
    border: 1px solid #3e3e3e;
}
/* Audio styling */
audio {
    width: 100%;
    margin-top: 1rem;
}
/* Lottie container */
.lottie-container {
    display: flex;
    justify-content: center;
    margin-bottom: 20px;
}
/* Footer */
.footer-note {
    text-align: center;
    font-size: 14px;
    opacity: 0.7;
    margin-top: 2rem;
}
/* Hide Streamlit's default menu and footer branding */
#MainMenu, footer {visibility: hidden;}
</style>
"""
st.markdown(CUSTOM_CSS, unsafe_allow_html=True)
# ---------------------------------------------------------------------
# 3) LOAD LOTTIE ANIMATION
# ---------------------------------------------------------------------
# Streamlit re-executes the whole script on every interaction, so the fetch is
# cached; the TTL lets a transient failure (cached as None) recover later.
@st.cache_data(ttl=3600, show_spinner=False)
def load_lottie_url(url: str, timeout: float = 10.0):
    """Fetch a Lottie animation JSON document from ``url``.

    The animation is purely decorative, so every failure mode (network error,
    timeout, non-200 status, body that is not valid JSON) yields ``None`` and
    the page shows a placeholder instead of crashing at startup.

    Args:
        url: Address of the Lottie JSON file.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The decoded JSON as returned by ``Response.json()``, or ``None``.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException:
        return None
    if response.status_code != 200:
        return None
    try:
        return response.json()
    except ValueError:  # requests' JSONDecodeError subclasses ValueError
        return None


LOTTIE_URL = "https://assets3.lottiefiles.com/temp/lf20_Q6h5zV.json"
lottie_animation = load_lottie_url(LOTTIE_URL)
# ---------------------------------------------------------------------
# 4) LOAD LLAMA 3 (GATED MODEL) - AUTHENTICATED WITH AN HF TOKEN
# ---------------------------------------------------------------------
@st.cache_resource(show_spinner=False)
def load_llama_pipeline(model_id: str, device: str, token: str):
    """Load a gated Llama 3 checkpoint and wrap it in a text-generation pipeline.

    The result is cached with ``st.cache_resource``, so the (very large) model
    is loaded once per ``(model_id, device, token)`` combination. Without the
    cache it would be reloaded on every button click.

    Args:
        model_id: Hugging Face repo ID of the model.
        device: ``device_map`` value ("auto" or "cpu") used to place the model.
        token: Hugging Face access token used to download the gated weights.

    Returns:
        A ``transformers`` text-generation pipeline.
    """
    # ``token`` replaces the deprecated ``use_auth_token`` keyword.
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    # Half precision is only used when a GPU is actually present; with
    # device_map="auto" on a CPU-only host, fp16 kernels may be unsupported.
    use_half = device == "auto" and torch.cuda.is_available()
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=token,
        torch_dtype=torch.float16 if use_half else torch.float32,
        device_map=device,
    )
    text_gen_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device_map=device,
    )
    return text_gen_pipeline
# ---------------------------------------------------------------------
# 5) REFINE SCRIPT (LLAMA)
# ---------------------------------------------------------------------
def generate_radio_script(
    user_input: str,
    pipeline_llama,
    max_new_tokens: int = 200,
    temperature: float = 0.9,
) -> str:
    """Ask the Llama pipeline to turn a rough concept into a short promo script.

    Args:
        user_input: The user's free-form promo concept.
        pipeline_llama: A callable with the Hugging Face text-generation
            pipeline interface (e.g. the object from ``load_llama_pipeline``).
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature (sampling is always enabled).

    Returns:
        The generated script, whitespace-stripped, followed by an attribution line.
    """
    system_prompt = (
        "You are a top-tier radio imaging producer using Llama 3. "
        "Take the user's concept and craft a short, creative promo script."
    )
    combined_prompt = f"{system_prompt}\nUser concept: {user_input}\nRefined script:"
    result = pipeline_llama(
        combined_prompt,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        # Return only the continuation. Stripping the echoed prompt by splitting
        # on "Refined script:" breaks when the user's concept contains it too.
        return_full_text=False,
    )
    script = result[0]["generated_text"].strip()
    return f"{script}\n\n(Generated by Llama 3 - Radio Imaging)"
# ---------------------------------------------------------------------
# 6) LOAD MUSICGEN
# ---------------------------------------------------------------------
@st.cache_resource(show_spinner=False)
def load_musicgen_model(model_id: str = "facebook/musicgen-small"):
    """Load a MusicGen model and its processor, cached across Streamlit reruns.

    Without ``st.cache_resource`` both objects would be re-instantiated on
    every click of the audio button.

    Args:
        model_id: Hugging Face repo ID of the MusicGen checkpoint.

    Returns:
        ``(model, processor)`` tuple.
    """
    mg_model = MusicgenForConditionalGeneration.from_pretrained(model_id)
    mg_processor = AutoProcessor.from_pretrained(model_id)
    return mg_model, mg_processor
# ---------------------------------------------------------------------
# 7) SIDEBAR
# ---------------------------------------------------------------------
# Static navigation placeholders; none of these entries is wired to behavior.
with st.sidebar:
    st.header("📚 Radio Library")
    st.write("**My Stations**")
    st.write("- Favorites")
    st.write("- Recently Generated")
    st.write("- Top Hits")
    st.write("---")
    st.write("**Settings**")
    st.markdown("<br>", unsafe_allow_html=True)
# ---------------------------------------------------------------------
# 8) HEADER
# ---------------------------------------------------------------------
col1, col2 = st.columns([3, 2], gap="large")
with col1:
    st.title("AI Radio Imaging with Llama 3")
    st.subheader("Gated Model + MusicGen Audio")
    # Markdown is kept at column 0 inside the string so it is not rendered as
    # an indented code block; the blank line separates the intro from the note.
    st.markdown(
        """
Create **radio imaging promos** and **jingles** with Llama 3 + MusicGen.

**Note**:
- You must have access to `"meta-llama/Meta-Llama-3-70B-Instruct"` on Hugging Face.
- You must provide your HF token as the `HF_TOKEN` environment variable (or Space secret).
"""
    )
with col2:
    # lottie_animation is None when the download failed; show a placeholder.
    if lottie_animation:
        with st.container():
            st_lottie(lottie_animation, height=180, loop=True, key="radio_lottie")
    else:
        st.write("*No animation loaded.*")
st.markdown("---")
# ---------------------------------------------------------------------
# 9) SCRIPT GENERATION
# ---------------------------------------------------------------------
st.subheader("📝 Step 1: Describe Your Promo Idea")
prompt = st.text_area(
    "Example: 'A 15-second hype jingle for a morning talk show, fun and energetic.'",
    height=120,
)
col_model, col_device = st.columns(2)
with col_model:
    llama_model_id = st.text_input(
        "Llama 3 Model ID",
        # Official Hugging Face repo ID of the gated Llama 3 70B Instruct model.
        value="meta-llama/Meta-Llama-3-70B-Instruct",
        help="Use the exact name you see on the Hugging Face model page.",
    )
with col_device:
    device_option = st.selectbox(
        "Device (GPU vs CPU)",
        ["auto", "cpu"],
        help="If you have GPU, 'auto' tries to use it; CPU might be slow.",
    )
# The token is only needed to download the gated Llama weights, so it is
# validated when generation is requested rather than with st.stop() here.
# That keeps the rest of the page, including the setup footer, visible.
my_token = os.getenv("HF_TOKEN")
if st.button("🚀 Generate Promo Script"):
    if not my_token:
        st.error("No HF_TOKEN found. Please set it in your HF Space secrets or environment variables.")
    elif not prompt.strip():
        st.error("Please type some concept first.")
    else:
        with st.spinner("Generating script with Llama 3..."):
            try:
                llm_pipeline = load_llama_pipeline(llama_model_id, device_option, my_token)
                final_script = generate_radio_script(prompt, llm_pipeline)
                # Persist across reruns so Step 2 can use it.
                st.session_state["final_script"] = final_script
                st.success("Promo script generated!")
                st.write(final_script)
            except Exception as e:  # UI boundary: surface any failure to the user
                st.error(f"Llama generation error: {e}")
st.markdown("---")
# ---------------------------------------------------------------------
# 10) AUDIO GENERATION: MUSICGEN
# ---------------------------------------------------------------------
st.subheader("🎶 Step 2: Generate Audio")
audio_length = st.slider("MusicGen Max Tokens (approx track length)", 128, 1024, 512, 64)
if st.button("🎧 Create Audio with MusicGen"):
    if "final_script" not in st.session_state:
        st.error("No script found. Please generate a script first.")
    else:
        with st.spinner("Creating audio..."):
            try:
                mg_model, mg_processor = load_musicgen_model()
                text_for_audio = st.session_state["final_script"]
                inputs = mg_processor(
                    text=[text_for_audio],
                    padding=True,
                    return_tensors="pt",
                )
                audio_values = mg_model.generate(**inputs, max_new_tokens=audio_length)
                sr = mg_model.config.audio_encoder.sampling_rate
                # First waveform of the first batch item; move it to the CPU
                # before converting so this also works if the model is on a GPU.
                waveform = audio_values[0, 0].cpu().numpy()
                # Write the WAV into memory instead of a fixed file on disk,
                # so concurrent sessions cannot overwrite each other's audio.
                wav_buffer = BytesIO()
                scipy.io.wavfile.write(wav_buffer, rate=sr, data=waveform)
                st.success("Audio generated! Press play below:")
                st.audio(wav_buffer.getvalue(), format="audio/wav")
            except Exception as e:  # UI boundary: surface any failure to the user
                st.error(f"MusicGen error: {e}")
# ---------------------------------------------------------------------
# 11) FOOTER
# ---------------------------------------------------------------------
st.markdown("---")
st.markdown(
    """
<div class="footer-note">
© 2025 Radio Imaging with Llama 3 — Built using Hugging Face & Streamlit. <br>
Log in or provide <code>HF_TOKEN</code> and ensure access to <strong>meta-llama/Meta-Llama-3-70B-Instruct</strong>.
</div>
""",
    unsafe_allow_html=True,
)