|
import gradio as gr |
|
import numpy as np |
|
import torch |
|
import random |
|
|
|
from transformers import AutoProcessor, MusicgenForConditionalGeneration |
|
|
|
# Pairs of hex colours; predict() picks one pair at random and passes it as
# `bars_color` to gr.make_waveform.
COLORS = [
    ["#ff0000", "#00ff00"],
    ["#00ff00", "#0000ff"],
    ["#0000ff", "#ff0000"],
]
|
|
|
# Checkpoint identifier handed to both `from_pretrained` calls below.
path = "facebook/musicgen-large"

# Loaded once at import time; predict() reuses these module-level objects.
processor = AutoProcessor.from_pretrained(path)
model = MusicgenForConditionalGeneration.from_pretrained(
    path,
    torch_dtype=torch.float16,  # load the weights in half precision
).to("cuda")
|
|
|
def predict(text):
    """Generate audio for a text prompt and render a waveform for it.

    Runs ``text`` through the module-level ``processor``, samples from the
    module-level ``model`` on CUDA under autocast, and returns the two values
    wired to the interface's ``"audio"`` and ``"video"`` outputs.

    Args:
        text: Prompt string; it is passed to the processor as a batch of one.

    Returns:
        A 2-tuple ``(audio, waveform)``:

        * ``audio`` is ``(sampling_rate, samples)`` where ``samples`` is
          ``outputs[0][0]`` as a float16 NumPy array.
        * ``waveform`` is the value returned by ``gr.make_waveform`` for the
          flattened ``outputs[0]``, drawn with 75 bars in a colour pair
          chosen at random from ``COLORS``.
    """
    inputs = processor(
        text=[text],
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    with torch.autocast("cuda"):
        outputs = model.generate(
            **inputs,
            do_sample=True,
            guidance_scale=3,
            max_new_tokens=768,
        )

    # Take the rate from the loaded checkpoint instead of repeating a literal
    # 32000, so the value stays right if `path` is pointed at another model.
    sampling_rate = model.config.audio_encoder.sampling_rate

    # Copy the generated clip off the GPU and cast it once; both return values
    # are derived from this single host array.
    # NOTE(review): presumably outputs is (batch, channels, samples) - confirm.
    clip = outputs[0].cpu().numpy().astype(np.float16)

    waveform = gr.make_waveform(
        (sampling_rate, clip.ravel()),
        bars_color=random.choice(COLORS),
        bar_count=75,
    )
    return (sampling_rate, clip[0]), waveform
|
|
|
|
|
# Title passed to the Gradio interface.
title = "MusicGen"

# One text box in; predict() returns a pair that feeds the "audio" and
# "video" outputs. The request queue is capped at 10 entries before launch.
gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Text prompt"),
    ],
    outputs=["audio", "video"],
    title=title,
    theme="gradio/monochrome",
).queue(max_size=10).launch()
|
|