# app.py — Hugging Face Space source (last commit 6e0dd6c, "Update app.py",
# by osanseviero; 1.35 kB).
import numpy as np
import random
import subprocess
import tempfile
import torch
import gradio as gr
from transformers import AutoProcessor, MusicgenForConditionalGeneration
# Candidate colour pairs for the waveform bars. predict() passes one pair,
# chosen with random.choice, as `bars_color` to gr.make_waveform. Per the
# Gradio docs, a two-colour value is drawn as a (start, end) gradient.
COLORS = [
    ["#ff0000", "#00ff00"],  # red -> green
    ["#00ff00", "#0000ff"],  # green -> blue
    ["#0000ff", "#ff0000"],  # blue -> red
]
# Hugging Face Hub checkpoint used for both preprocessing and generation.
path = "facebook/musicgen-small"
processor = AutoProcessor.from_pretrained(path)
# Load the weights in half precision, then move the model onto the GPU.
# nn.Module.to() returns the module itself, so rebinding is equivalent to
# chaining the call.
model = MusicgenForConditionalGeneration.from_pretrained(
    path,
    torch_dtype=torch.float16,
)
model = model.to("cuda")
def predict(text):
    """Generate a music clip and a small waveform video from a text prompt.

    Args:
        text: Free-form description, tokenized by the module-level
            ``processor`` as a single-item batch.

    Returns:
        A tuple ``((sampling_rate, samples), video_path)``. The first item
        feeds Gradio's ``"audio"`` output. The second is the path of an MP4
        waveform video rescaled to 250x150 for the ``"video"`` output.

    Raises:
        subprocess.CalledProcessError: if ffmpeg fails to rescale the video.
    """
    inputs = processor(
        text=[text],
        padding=True,
        return_tensors="pt",
    ).to("cuda")
    with torch.autocast("cuda"):
        outputs = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=512)

    # Read the rate from the model config instead of hard-coding 32000, so a
    # different MusicGen checkpoint in `path` still plays at the right speed.
    sampling_rate = model.config.audio_encoder.sampling_rate
    # Assumes generate() returns (batch, channels, samples), per the
    # transformers MusicGen docs. Take the first channel of the only batch
    # item once and reuse it for both outputs. For the mono checkpoint this
    # equals the original outputs[0].ravel().
    waveform = outputs[0, 0].cpu().numpy().astype(np.float16)

    video_path = gr.make_waveform(
        (sampling_rate, waveform),
        bars_color=random.choice(COLORS),
        bar_count=75,
    )

    # The output file needs a .mp4 suffix so ffmpeg can infer the container.
    # It must also outlive this function, because Gradio reads it after we
    # return; the original default delete=True removed it on return.
    # NOTE(review): these files are never cleaned up; consider periodic removal.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        scaled_path = tmp.name
    # Pass an argument list with no shell. str.split() broke on paths with
    # spaces and handed ffmpeg literal quote characters around the filter.
    subprocess.run(
        ["ffmpeg", "-y", "-i", video_path, "-vf", "scale=250:150", scaled_path],
        check=True,
    )
    return (sampling_rate, waveform), scaled_path
title = "MusicGen"
# Keep a reference to the interface under the name `demo`, which Gradio's
# reload mode (`gradio app.py`) looks up by default, instead of discarding it.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Text prompt"),
    ],
    # predict() returns (audio, video_path), matching these two components.
    outputs=["audio", "video"],
    title=title,
    theme="gradio/monochrome",
)
# Cap the pending-request queue at 10 events. Blocks.queue() returns the same
# object, so this matches the original chained queue(...).launch() call.
demo.queue(max_size=10)
demo.launch()