Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,591 Bytes
17d10a7 a15d204 d448add db46bfb cf3593c c243adb cf3593c f0b5707 613bd9e f0b5707 7bbdf94 613bd9e f0b5707 613bd9e cf3593c d0384c8 cf3593c 17d10a7 cf3593c 17d10a7 cf3593c 17d10a7 d448add cf3593c d448add 17d10a7 f0b5707 cf3593c d448add 17d10a7 cf3593c d448add cf3593c 3fe530b 7bbdf94 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import gradio as gr
import os
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
pipeline,
AutoProcessor,
MusicgenForConditionalGeneration
)
from scipy.io.wavfile import write
import tempfile
from dotenv import load_dotenv
import spaces
load_dotenv()
hf_token = os.getenv("HF_TOKEN")
@spaces.GPU(duration=120)
def load_llama_pipeline_zero_gpu(model_id: str, token: str):
    """Load a Llama text-generation pipeline on the Zero GPU worker.

    Args:
        model_id: Hugging Face model repo id (e.g. "meta-llama/Meta-Llama-3-70B").
        token: Hugging Face access token for gated model downloads.

    Returns:
        A transformers text-generation pipeline on success, or the error
        message as a string on failure (callers must check the return type).
    """
    try:
        # `token=` replaces the deprecated `use_auth_token=` keyword
        # (removed in recent transformers releases).
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            token=token,
            torch_dtype=torch.float16,  # half precision to fit large checkpoints
            device_map="auto",          # shard across whatever devices are visible
            trust_remote_code=True,
        )
        return pipeline("text-generation", model=model, tokenizer=tokenizer)
    except Exception as e:
        # Preserve the original best-effort contract: return the error text
        # instead of raising, so the UI can display it.
        return str(e)
@spaces.GPU(duration=120)
def generate_audio(prompt: str, audio_length: int, mg_model, mg_processor):
    """Generate a WAV file from a text prompt with MusicGen.

    Args:
        prompt: Text description of the audio to synthesize.
        audio_length: Maximum number of new audio tokens to generate.
        mg_model: A loaded MusicgenForConditionalGeneration instance.
        mg_processor: The matching AutoProcessor.

    Returns:
        Path to a temporary .wav file on success, or an error message string.
    """
    try:
        mg_model.to("cuda")
        inputs = mg_processor(text=[prompt], padding=True, return_tensors="pt")
        # BUG FIX: the processor returns CPU tensors; they must be moved to
        # the same device as the model or generate() raises a device-mismatch
        # RuntimeError.
        inputs = {k: v.to("cuda") for k, v in inputs.items()}
        outputs = mg_model.generate(**inputs, max_new_tokens=audio_length)
        mg_model.to("cpu")  # free GPU memory for the next Zero GPU request
        sr = mg_model.config.audio_encoder.sampling_rate
        audio_data = outputs[0, 0].cpu().numpy()
        # Scale to 16-bit PCM; guard against all-zero (silent) output so the
        # peak normalization cannot divide by zero.
        peak = abs(audio_data).max()
        scale = 32767 / peak if peak > 0 else 0.0
        normalized_audio = (audio_data * scale).astype("int16")
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav:
            write(temp_wav.name, sr, normalized_audio)
        return temp_wav.name
    except Exception as e:
        return f"Error generating audio: {e}"
# ---- Gradio UI -------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# 🎧 AI Radio Imaging with Llama 3 + MusicGen (Zero GPU)")
    user_prompt = gr.Textbox(label="Enter your promo idea", placeholder="E.g., A 15-second hype jingle for a morning talk show.")
    llama_model_id = gr.Textbox(label="Llama 3 Model ID", value="meta-llama/Meta-Llama-3-70B")
    # BUG FIX: this widget was named `hf_token`, shadowing (and clobbering)
    # the module-level token loaded from the environment above.
    hf_token_box = gr.Textbox(label="Hugging Face Token", type="password")
    audio_length = gr.Slider(label="Audio Length (tokens)", minimum=128, maximum=1024, step=64, value=512)
    generate_button = gr.Button("Generate Promo Script and Audio")
    script_output = gr.Textbox(label="Generated Script")
    audio_output = gr.Audio(label="Generated Audio", type="filepath")
    # NOTE(review): the click handler is a placeholder stub — it echoes the
    # prompt and returns no audio; the Llama/MusicGen functions above are
    # never invoked from the UI yet.
    generate_button.click(
        fn=lambda prompt, model_id, token, length: (prompt, None),  # Simplify for demo
        inputs=[user_prompt, llama_model_id, hf_token_box, audio_length],
        outputs=[script_output, audio_output]
    )
demo.launch(debug=True)
|