import os
import glob
import random
import re

import numpy as np
import torch
import gradio as gr
from einops import rearrange, repeat
from diffusers import AutoencoderKL
from transformers import SpeechT5HifiGan
from scipy.io import wavfile

# Project-local imports
from utils import load_t5, load_clap
from train import RF
from constants import build_model
# Flash-attention SDP kernels are not supported on all GPUs; disable them so
# PyTorch falls back to the math/memory-efficient implementations.
torch.backends.cuda.enable_flash_sdp(False)
# Global handles for the loaded model and shared resources
global_model = None
global_t5 = None
global_clap = None
global_vae = None
global_vocoder = None
global_diffusion = None

# Checkpoint and output directories
MODELS_DIR = "/content/models"
GENERATIONS_DIR = "/content/generations"
def prepare(t5, clip, img, prompt):
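    # The original body was elided here ("remains unchanged"). The sketch
    # below follows the upstream Flux-style prepare(): patchify the latent
    # into 2x2 tokens, build positional ids, and encode the prompt with the
    # T5 and CLAP text towers. Treat it as an assumption about the elided
    # code, not the verbatim original.
    bs, c, h, w = img.shape
    if bs == 1 and not isinstance(prompt, str):
        bs = len(prompt)

    # Flatten 2x2 latent patches into a token sequence
    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    if img.shape[0] == 1 and bs > 1:
        img = repeat(img, "1 ... -> bs ...", bs=bs)

    # Per-token (batch, row, col) position ids
    img_ids = torch.zeros(h // 2, w // 2, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

    if isinstance(prompt, str):
        prompt = [prompt]

    # Sequence-level text embedding (T5) and pooled embedding (CLAP)
    txt = t5(prompt)
    if txt.shape[0] == 1 and bs > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
    txt_ids = torch.zeros(bs, txt.shape[1], 3)

    vec = clip(prompt)
    if vec.shape[0] == 1 and bs > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=bs)

    return img, {
        "img_ids": img_ids.to(img.device),
        "txt": txt.to(img.device),
        "txt_ids": txt_ids.to(img.device),
        "y": vec.to(img.device),
    }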
def unload_current_model():
    global global_model
    if global_model is not None:
        del global_model
        torch.cuda.empty_cache()
        global_model = None
def load_model(model_name):
    global global_model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    unload_current_model()

    # Infer the model size from the checkpoint filename
    if 'musicflow_b' in model_name:
        model_size = "base"
    elif 'musicflow_g' in model_name:
        model_size = "giant"
    elif 'musicflow_l' in model_name:
        model_size = "large"
    elif 'musicflow_s' in model_name:
        model_size = "small"
    else:
        model_size = "base"  # Default to base if unrecognized

    print(f"Loading {model_size} model: {model_name}")
    model_path = os.path.join(MODELS_DIR, model_name)
    global_model = build_model(model_size).to(device)
    # Load the EMA weights from the checkpoint
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    global_model.load_state_dict(state_dict['ema'])
    global_model.eval()
    global_model.model_path = model_path
def load_resources():
    global global_t5, global_clap, global_vae, global_vocoder, global_diffusion
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print("Loading T5 and CLAP models...")
    global_t5 = load_t5(device, max_length=256)
    global_clap = load_clap(device, max_length=256)

    print("Loading VAE and vocoder...")
    global_vae = AutoencoderKL.from_pretrained('cvssp/audioldm2', subfolder="vae").to(device)
    global_vocoder = SpeechT5HifiGan.from_pretrained('cvssp/audioldm2', subfolder="vocoder").to(device)

    print("Initializing diffusion...")
    global_diffusion = RF()
    print("Base resources loaded successfully!")
def generate_music(prompt, seed, cfg_scale, steps, duration, progress=gr.Progress()):
    output_dir = GENERATIONS_DIR
    os.makedirs(output_dir, exist_ok=True)
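    # The rest of the body was elided ("remains the same"). The sketch below
    # reconstructs a plausible FluxMusic sampling flow under stated
    # assumptions: global_diffusion.sample_with_xps() exists as in the
    # FluxMusic repo, the latent grid is 2x2-patchified with 8 VAE channels,
    # and the vocoder emits 16 kHz audio. Treat the shapes, the negative
    # prompt, and the sampler call as hypothetical, not the verbatim original.
    if global_model is None:
        return "No model loaded. Select a model first.", None

    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(int(seed))

    # Latent mel canvas; the real height/width would depend on `duration`
    latent_h, latent_w = 256, 16
    init_noise = torch.randn(1, 8, latent_h, latent_w, device=device)

    # Conditional and unconditional embeddings for classifier-free guidance
    img, conds = prepare(global_t5, global_clap, init_noise, [prompt])
    _, unconds = prepare(global_t5, global_clap, init_noise, ["low quality, gentle"])

    with torch.no_grad():
        imgs = global_diffusion.sample_with_xps(
            global_model, img, conds=conds, null_cond=unconds,
            sample_steps=int(steps), cfg=float(cfg_scale),
        )

    # Un-patchify tokens back into a latent mel spectrogram
    imgs = rearrange(imgs[-1], "b (h w) (c ph pw) -> b c (h ph) (w pw)",
                     h=latent_h // 2, w=latent_w // 2, ph=2, pw=2)

    # Decode latents to a mel spectrogram, then vocode to a waveform
    latents = 1 / global_vae.config.scaling_factor * imgs
    mel = global_vae.decode(latents).sample
    if mel.dim() == 4:
        mel = mel.squeeze(1)
    waveform = global_vocoder(mel)[0].cpu().float().detach().numpy()

    output_path = os.path.join(output_dir, f"generation_{seed}.wav")
    wavfile.write(output_path, 16000, waveform)
    return f"Generated audio saved to {output_path}", output_path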
# Load base resources at startup
load_resources()

# Collect checkpoint files (.pt) from the models directory
model_files = glob.glob(os.path.join(MODELS_DIR, "*.pt"))
model_choices = [os.path.basename(f) for f in model_files]

# Put 'musicflow_b.pt' first so it becomes the default choice if present
default_model = 'musicflow_b.pt'
if default_model in model_choices:
    model_choices.remove(default_model)
    model_choices.insert(0, default_model)
# Dark grey theme
theme = gr.themes.Monochrome(
    primary_hue="gray",
    secondary_hue="gray",
    neutral_hue="gray",
    radius_size=gr.themes.sizes.radius_sm,
)
# Gradio interface
with gr.Blocks(theme=theme) as iface:
    gr.Markdown(
        """
        <div style="text-align: center;">
            <h1>FluxMusic Generator</h1>
            <p>Generate music from text prompts using the FluxMusic model.</p>
        </div>
        """)

    with gr.Row():
        # default_model is already first in model_choices when present;
        # guard against an empty models directory
        model_dropdown = gr.Dropdown(
            choices=model_choices,
            label="Select Model",
            value=model_choices[0] if model_choices else None,
        )
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
        seed = gr.Number(label="Seed", value=0)
    with gr.Row():
        cfg_scale = gr.Slider(minimum=1, maximum=40, step=0.1, label="CFG Scale", value=20)
        steps = gr.Slider(minimum=10, maximum=200, step=1, label="Steps", value=100)
        duration = gr.Number(label="Duration (seconds)", value=10, minimum=10, maximum=300, step=1)

    generate_button = gr.Button("Generate Music")
    output_status = gr.Textbox(label="Generation Status")
    output_audio = gr.Audio(type="filepath")
    def on_model_change(model_name):
        load_model(model_name)

    model_dropdown.change(on_model_change, inputs=[model_dropdown])
    generate_button.click(
        generate_music,
        inputs=[prompt, seed, cfg_scale, steps, duration],
        outputs=[output_status, output_audio],
    )
    # Load the default model on startup; event listeners such as load()
    # must be registered inside the Blocks context
    default_model_path = os.path.join(MODELS_DIR, default_model)
    if os.path.exists(default_model_path):
        iface.load(lambda: load_model(default_model), inputs=None, outputs=None)

# Launch the interface
iface.launch()