File size: 5,112 Bytes
afe1a07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357478d
 
 
afe1a07
 
357478d
 
afe1a07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357478d
 
 
afe1a07
357478d
 
afe1a07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357478d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os
import torch
import gradio as gr
from einops import rearrange, repeat
from diffusers import AutoencoderKL
from transformers import SpeechT5HifiGan
from scipy.io import wavfile
import glob
import random
import numpy as np
import re

# Import necessary functions and classes
from utils import load_t5, load_clap
from train import RF
from constants import build_model

# Turn off PyTorch's flash scaled-dot-product-attention backend. This call is
# unconditional: flash SDP is disabled whether or not it would be available.
torch.backends.cuda.enable_flash_sdp(False)

# Module-level slots for loaded resources. global_model is (re)assigned by
# load_model() and reset by unload_current_model(); the remaining slots are
# filled by load_resources().
global_model = global_t5 = global_clap = global_vae = global_vocoder = global_diffusion = None

# Filesystem locations: MODELS_DIR is scanned for *.pt checkpoints, and
# GENERATIONS_DIR is the output directory used by generate_music().
MODELS_DIR = "/content/models"
GENERATIONS_DIR = "/content/generations"

def prepare(t5, clip, img, prompt):
    """Placeholder for the input-preparation step (body elided in this copy).

    Presumably builds model inputs from ``t5``/``clip``, ``img`` and ``prompt``
    -- TODO confirm against the full implementation. As written it ignores all
    arguments and returns ``None``.
    """
    # NOTE: the original body was elided here ("remains unchanged"); restore it
    # before relying on this function.
    return None

def unload_current_model():
    """Drop the currently loaded model, if any, and release cached CUDA memory.

    After the call ``global_model`` is ``None``. With no model loaded this is
    a no-op (``torch.cuda.empty_cache`` is not called).
    """
    global global_model
    if global_model is None:
        return
    # Drop the module-level reference first so the allocator's cached blocks
    # held by the model can actually be returned by empty_cache().
    global_model = None
    torch.cuda.empty_cache()

def _infer_model_size(model_name):
    """Map a checkpoint filename to the size key passed to ``build_model``.

    Tags are checked in order (base, giant, large, small) and the first one
    found as a substring of ``model_name`` wins; unrecognized names fall back
    to "base".
    """
    for tag, size in (
        ("musicflow_b", "base"),
        ("musicflow_g", "giant"),
        ("musicflow_l", "large"),
        ("musicflow_s", "small"),
    ):
        if tag in model_name:
            return size
    return "base"  # Default to base if unrecognized


def load_model(model_name):
    """Load checkpoint ``model_name`` from MODELS_DIR into ``global_model``.

    Any previously loaded model is unloaded first. The model size is inferred
    from the filename; the checkpoint must be a dict holding the weights under
    the ``'ema'`` key. On success ``global_model`` is the model in eval mode,
    on CUDA when available (else CPU), with a ``model_path`` attribute
    recording where it came from.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise (e.g. missing file,
        ``KeyError`` for a checkpoint without ``'ema'``, size mismatch). In
        that case ``global_model`` is left as ``None`` rather than pointing at
        a model with uninitialized weights.
    """
    global global_model
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Free the previous model before building the next one so two models are
    # never resident at the same time.
    unload_current_model()

    model_size = _infer_model_size(model_name)
    print(f"Loading {model_size} model: {model_name}")

    model_path = os.path.join(MODELS_DIR, model_name)
    model = build_model(model_size).to(device)
    # Deserialize on CPU; load_state_dict copies tensors onto the model's device.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict['ema'])
    model.eval()
    model.model_path = model_path
    # Publish only a fully initialized model; previously the global was set
    # before the weights were loaded, so a failed load left a randomly
    # initialized model in place.
    global_model = model

def load_resources():
    """Load the model-independent components into module globals.

    Sets ``global_t5`` and ``global_clap`` via ``load_t5`` / ``load_clap``
    (max_length=256), ``global_vae`` and ``global_vocoder`` from the
    'cvssp/audioldm2' repository, and ``global_diffusion`` to a fresh ``RF()``.
    Components are placed on CUDA when available, otherwise on CPU.
    """
    global global_t5, global_clap, global_vae, global_vocoder, global_diffusion

    target = "cpu"
    if torch.cuda.is_available():
        target = "cuda"

    print("Loading T5 and CLAP models...")
    global_t5 = load_t5(target, max_length=256)
    global_clap = load_clap(target, max_length=256)

    print("Loading VAE and vocoder...")
    # Both components come from subfolders of the same pretrained repo.
    audioldm2_repo = 'cvssp/audioldm2'
    global_vae = AutoencoderKL.from_pretrained(audioldm2_repo, subfolder="vae").to(target)
    global_vocoder = SpeechT5HifiGan.from_pretrained(audioldm2_repo, subfolder="vocoder").to(target)

    print("Initializing diffusion...")
    global_diffusion = RF()

    print("Base resources loaded successfully!")

def generate_music(prompt, seed, cfg_scale, steps, duration, progress=gr.Progress()):
    """Callback for the "Generate Music" button (generation logic elided here).

    The UI wires it to outputs ``[output_status, output_audio]``. As shown, it
    only ensures GENERATIONS_DIR exists and returns ``None``; the elided body
    is expected to produce the status text and audio file path.
    """
    # NOTE: generation code elided in this copy ("remains largely unchanged").
    # Keep the gr.Progress() default: Gradio uses it to supply a progress tracker.
    output_dir = GENERATIONS_DIR
    os.makedirs(output_dir, exist_ok=True)
    # NOTE: the rest of the original body was elided here; it presumably
    # writes its output under output_dir.

# Load the model-independent resources once at startup.
load_resources()

# Discover checkpoint files in MODELS_DIR. glob returns paths in arbitrary
# (filesystem) order, so sort them to make the dropdown order -- and the
# fallback choice when musicflow_b.pt is missing -- deterministic.
model_files = sorted(glob.glob(os.path.join(MODELS_DIR, "*.pt")))
model_choices = [os.path.basename(f) for f in model_files]

# Move 'musicflow_b.pt' to the front when present so it is the default choice.
default_model = 'musicflow_b.pt'
if default_model in model_choices:
    model_choices.remove(default_model)
    model_choices.insert(0, default_model)

# Set up dark grey theme
theme = gr.themes.Monochrome(
    primary_hue="gray",
    secondary_hue="gray",
    neutral_hue="gray",
    radius_size=gr.themes.sizes.radius_sm,
)

# Initial dropdown selection: the preferred default when present, otherwise
# the first discovered checkpoint, or None when MODELS_DIR has no .pt files
# (indexing model_choices[0] unconditionally raised IndexError in that case).
if default_model in model_choices:
    initial_model = default_model
elif model_choices:
    initial_model = model_choices[0]
else:
    initial_model = None

# Gradio Interface
with gr.Blocks(theme=theme) as iface:
    gr.Markdown(
        """
        <div style="text-align: center;">
            <h1>FluxMusic Generator</h1>
            <p>Generate music based on text prompts using FluxMusic model.</p>
        </div>
        """)
    
    with gr.Row():
        model_dropdown = gr.Dropdown(choices=model_choices, label="Select Model", value=initial_model)
    
    with gr.Row():
        prompt = gr.Textbox(label="Prompt")
        seed = gr.Number(label="Seed", value=0)
    
    with gr.Row():
        cfg_scale = gr.Slider(minimum=1, maximum=40, step=0.1, label="CFG Scale", value=20)
        steps = gr.Slider(minimum=10, maximum=200, step=1, label="Steps", value=100)
        duration = gr.Number(label="Duration (seconds)", value=10, minimum=10, maximum=300, step=1)
    
    generate_button = gr.Button("Generate Music")
    output_status = gr.Textbox(label="Generation Status")
    output_audio = gr.Audio(type="filepath")

    def on_model_change(model_name):
        """Load the checkpoint picked in the dropdown; ignore an empty selection."""
        # A cleared dropdown reports None/""; load_model would raise TypeError on None.
        if model_name:
            load_model(model_name)

    def load_initial_model():
        """Ensure the checkpoint shown initially in the dropdown is in memory."""
        # iface.load fires on every page load; skip the expensive reload when
        # this checkpoint is already the one loaded (load_model records model_path).
        initial_path = os.path.join(MODELS_DIR, initial_model)
        if global_model is not None and getattr(global_model, "model_path", None) == initial_path:
            return
        load_model(initial_model)

    model_dropdown.change(on_model_change, inputs=[model_dropdown])
    generate_button.click(generate_music, inputs=[prompt, seed, cfg_scale, steps, duration], outputs=[output_status, output_audio])

    # Load the initially selected model on page load -- including when
    # musicflow_b.pt is absent and another checkpoint was selected, which
    # previously left no model loaded at all.
    if initial_model is not None and os.path.exists(os.path.join(MODELS_DIR, initial_model)):
        iface.load(load_initial_model, inputs=None, outputs=None)

# Launch the interface
iface.launch()