# FluxMusicGUI / app.py
# (Hugging Face Space page header residue: author flosstradamus, commit c5125f1
# "Update app.py" (verified), raw / history / blame links, file size 4.61 kB.)
import os
import torch
import gradio as gr
from einops import rearrange, repeat
from diffusers import AutoencoderKL
from transformers import SpeechT5HifiGan
from scipy.io import wavfile
import glob
import random
import numpy as np
import re
import requests
import time
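# ... (keep the imports and global variables)
# NOTE(review): the line above is a placeholder left over from an edit. The code
# below references module-level names that appear to be missing here
# (global_model, global_t5, global_clap, global_vae, global_vocoder,
# global_diffusion, theme, default_model) and helpers (prepare, load_resources,
# load_model). Confirm they are defined before this module is run.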
def generate_music(prompt, seed, cfg_scale, steps, duration, device, batch_size=1, progress=gr.Progress()):
    """Generate audio for ``prompt`` in 10-second segments.

    Relies on module-level state loaded elsewhere (``global_model``,
    ``global_t5``, ``global_clap``, ``global_vae``, ``global_vocoder``,
    ``global_diffusion``) and on the ``prepare`` helper.

    Args:
        prompt: Text description of the music to generate.
        seed: Base RNG seed; ``0`` (or an empty value) picks a random seed
            in ``[1, 1000000]``. Segment ``i`` is seeded with ``seed + i``.
        cfg_scale: Classifier-free guidance scale passed to the sampler.
        steps: Number of sampling steps passed to the sampler.
        duration: Requested length in seconds. Generation is rounded up to
            whole 10-second segments, then the result is trimmed.
        device: ``"cuda"`` or ``"cpu"``; falls back to CPU when CUDA is
            unavailable.
        batch_size: Max number of latents per sampler call (values below 1
            are treated as 1).
        progress: Gradio progress tracker.

    Returns:
        ``(status_message, audio)``. ``audio`` is ``None`` on error,
        otherwise a ``(sample_rate, waveform)`` tuple.
    """
    global global_model, global_t5, global_clap, global_vae, global_vocoder, global_diffusion
    if global_model is None:
        return "Please select and load a model first.", None
    if any(res is None for res in (global_t5, global_clap, global_vae, global_vocoder, global_diffusion)):
        return "Resources not properly loaded. Please reload the page and try again.", None
    if not duration or duration <= 0:
        return "Duration must be greater than 0 seconds.", None

    seed = int(seed or 0)
    if seed == 0:
        seed = random.randint(1, 1000000)
    print(f"Using seed: {seed}")

    # Ensure we're using CPU if CUDA is not available
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available. Falling back to CPU.")
        device = "cpu"

    # range() rejects a step of 0, so clamp the batch size.
    batch_size = max(1, int(batch_size))

    # Calculate the number of segments needed for the desired duration
    segment_duration = 10  # Each segment is 10 seconds
    num_segments = int(np.ceil(duration / segment_duration))

    latent_size = (256, 16)
    conds_txt = [prompt]
    unconds_txt = ["low quality, gentle"]  # negative prompt for CFG
    num_prompts = len(conds_txt)

    all_waveforms = []
    # Scoped instead of torch.set_grad_enabled(False), which would leave
    # gradients disabled for the calling thread after this function returns.
    with torch.no_grad():
        for i in range(num_segments):
            progress(i / num_segments, desc=f"Generating segment {i+1}/{num_segments}")
            # Offset the seed per segment: segments differ from each other, but
            # the whole result is reproducible for a given base seed.
            torch.manual_seed(seed + i)
            init_noise = torch.randn(num_prompts, 8, latent_size[0], latent_size[1]).to(device)
            img, conds = prepare(global_t5, global_clap, init_noise, conds_txt)
            _, unconds = prepare(global_t5, global_clap, init_noise, unconds_txt)

            # Sample the latents in chunks of at most batch_size.
            sampled = []
            for start in range(0, img.shape[0], batch_size):
                stop = min(start + batch_size, img.shape[0])
                batch_conds = {k: v[start:stop] for k, v in conds.items()}
                batch_unconds = {k: v[start:stop] for k, v in unconds.items()}
                trajectory = global_diffusion.sample_with_xps(
                    global_model, img[start:stop], conds=batch_conds, null_cond=batch_unconds,
                    sample_steps=steps, cfg=cfg_scale
                )
                # Keep only the final step of the returned sequence.
                sampled.append(trajectory[-1])
            images = torch.cat(sampled, dim=0)

            # Unpatchify the token sequence back into a latent image.
            images = rearrange(
                images,
                "b (h w) (c ph pw) -> b c (h ph) (w pw)",
                h=128,
                w=8,
                ph=2,
                pw=2,
            )
            latents = 1 / global_vae.config.scaling_factor * images
            mel_spectrogram = global_vae.decode(latents).sample
            x_i = mel_spectrogram[0]
            if x_i.dim() == 4:
                x_i = x_i.squeeze(1)
            waveform = global_vocoder(x_i)
            all_waveforms.append(waveform[0].cpu().float().detach().numpy())

    # NOTE(review): the remainder of this function had been replaced by a
    # "keep the rest of the function unchanged" placeholder, so the generated
    # audio was discarded and None was returned. The tail below is a minimal
    # reconstruction. Confirm it matches the intended behaviour, e.g. whether
    # the audio should also be written to disk.
    # SpeechT5HifiGan configs expose `sampling_rate` (16000 by default).
    sample_rate = getattr(getattr(global_vocoder, "config", None), "sampling_rate", 16000)
    final_waveform = np.concatenate(all_waveforms)[: int(duration * sample_rate)]
    progress(1.0, desc="Audio generation complete")
    return f"Generated {duration} s of audio (seed {seed}).", (sample_rate, final_waveform)
# ... (keep the rest of the file unchanged)
# Gradio Interface
with gr.Blocks(theme=theme) as iface:
    # ... (keep the interface definition unchanged)

    def on_load_model_click(model_name, device, url):
        """Load shared resources, then a model chosen by dropdown name or URL.

        Args:
            model_name: Selection from the model dropdown.
            device: ``"cuda"`` or ``"cpu"``; falls back to CPU when CUDA is
                unavailable.
            url: Optional model URL. A non-blank value takes precedence over
                ``model_name``.

        Returns:
            The status message from ``load_resources`` if it reported failure,
            otherwise the result of ``load_model``.
        """
        # Ensure we're using CPU if CUDA is not available
        if device == "cuda" and not torch.cuda.is_available():
            print("CUDA is not available. Falling back to CPU.")
            device = "cpu"
        resource_status = load_resources(device)
        # load_resources signals failure with a message containing "Failed"
        # (presumably always a str; verify against its definition).
        if "Failed" in resource_status:
            return resource_status
        # Treat a whitespace-only URL field as empty.
        url = (url or "").strip()
        if url:
            return load_model(None, device, model_url=url)
        return load_model(model_name, device)

    load_model_button.click(on_load_model_click, inputs=[model_dropdown, device_choice, model_url], outputs=[model_status])
    generate_button.click(generate_music, inputs=[prompt, seed, cfg_scale, steps, duration, device_choice], outputs=[output_status, output_audio])

    # Load default model and resources on startup. Route the status to the
    # same box the "load model" button uses, so startup failures are visible.
    iface.load(lambda: on_load_model_click(default_model, "cpu", None), inputs=None, outputs=[model_status])

# Launch the interface
iface.launch()