import gradio as gr | |
import numpy as np | |
import torch | |
import random | |
from transformers import AutoProcessor, MusicgenForConditionalGeneration | |
# Colour pairs offered to the waveform renderer as ``bars_color``; one pair is
# picked at random per request. Each pair goes from one primary to the next,
# wrapping around: red -> green, green -> blue, blue -> red.
_PRIMARY_HEXES = ["#ff0000", "#00ff00", "#0000ff"]
COLORS = [
    [start, end]
    for start, end in zip(_PRIMARY_HEXES, _PRIMARY_HEXES[1:] + _PRIMARY_HEXES[:1])
]
# Checkpoint identifier handed to both ``from_pretrained`` calls below.
path = "facebook/musicgen-large"

# Processor used to turn the text prompt into model inputs.
processor = AutoProcessor.from_pretrained(path)

# Load the weights as float16, then move the model to the CUDA device.
model = MusicgenForConditionalGeneration.from_pretrained(
    path,
    torch_dtype=torch.float16,
)
model = model.to("cuda")
def predict(text):
    """Generate audio for a text prompt and render a waveform video of it.

    Args:
        text: Prompt string forwarded to the processor as a single-item batch.

    Returns:
        A 2-tuple ``((sample_rate, samples), video)``: the first element is
        the audio as a sample rate plus a NumPy array, the second is the
        value returned by ``gr.make_waveform`` for the same audio.
    """
    inputs = processor(
        text=[text],
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    with torch.autocast("cuda"):
        outputs = model.generate(
            **inputs,
            do_sample=True,
            guidance_scale=3,
            max_new_tokens=512,
        )

    # Take the rate from the loaded checkpoint rather than repeating a
    # literal, so the audio and the waveform can never disagree on it.
    sample_rate = model.config.audio_encoder.sampling_rate

    # Move the first generated item off the GPU once and reuse it for both
    # outputs. Cast up to float32: float16 cannot represent 32767 exactly
    # (it rounds to 32768), so a float16 array scaled to 16-bit PCM
    # downstream can overflow int16 on full-scale positive samples.
    clip = outputs[0].cpu().numpy().astype(np.float32)

    waveform_video = gr.make_waveform(
        (sample_rate, clip.ravel()),
        bars_color=random.choice(COLORS),
        bar_count=75,
    )
    return (sample_rate, clip[0]), waveform_video
title = "MusicGen"

# One text box in; an audio output and a video output back from ``predict``.
prompt_box = gr.Textbox(label="Text prompt")
demo = gr.Interface(
    fn=predict,
    inputs=[prompt_box],
    outputs=["audio", "video"],
    title=title,
    theme="gradio/monochrome",
)

# Enable queuing with ``max_size=10`` and start serving the app.
demo.queue(max_size=10).launch()