File size: 4,608 Bytes
afe1a07
 
 
 
 
 
 
b300542
afe1a07
 
 
6a91f5a
e39e85d
afe1a07
c5125f1
0dacaeb
e39e85d
0dacaeb
 
 
 
 
c4fb7a3
 
 
0dacaeb
 
 
 
 
 
 
c5125f1
 
 
 
 
0dacaeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e39e85d
0dacaeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5125f1
0dacaeb
c5125f1
afe1a07
 
 
c5125f1
afe1a07
6a91f5a
c5125f1
 
 
 
 
c4fb7a3
 
 
6a91f5a
 
 
 
368ac79
 
6a91f5a
 
771145b
6a91f5a
c4fb7a3
afe1a07
357478d
20d68fc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
import torch
import gradio as gr
from einops import rearrange, repeat
from diffusers import AutoencoderKL
from transformers import SpeechT5HifiGan
from scipy.io import wavfile
import glob
import random
import numpy as np
import re
import requests
import time

# ... (keep the imports and global variables)

def generate_music(prompt, seed, cfg_scale, steps, duration, device, batch_size=1, progress=gr.Progress()):
    """Generate audio for ``prompt`` as a sequence of fixed-length segments.

    For every segment the function draws fresh noise, runs the module-level
    diffusion sampler with the loaded model, decodes the result with the VAE
    and feeds the first decoded item to the vocoder. The resulting NumPy
    waveforms are collected in ``all_waveforms``.

    Args:
        prompt: Text used as the conditioning prompt.
        seed: Base random seed. ``0`` (or an empty value) picks a random seed
            in ``[1, 1000000]``. Segment ``i`` is seeded with ``seed + i``.
        cfg_scale: Forwarded to the sampler as ``cfg``.
        steps: Forwarded to the sampler as ``sample_steps``.
        duration: Requested length in seconds; rounded up to a whole number
            of 10-second segments. Must be greater than zero.
        device: ``"cuda"`` or ``"cpu"``. Falls back to ``"cpu"`` when CUDA is
            requested but unavailable.
        batch_size: Number of noise samples handed to the sampler per call.
            Values below 1 are treated as 1.
        progress: Gradio progress tracker, updated once per segment.

    Returns:
        ``(message, None)`` when the model or resources are not loaded or the
        duration is not positive.

        NOTE(review): the code visible here returns nothing after the segment
        loop. The elided remainder of the function presumably assembles
        ``all_waveforms`` into the ``(status, audio)`` pair the UI binds to
        ``output_status`` / ``output_audio`` -- confirm.
    """
    global global_model, global_t5, global_clap, global_vae, global_vocoder, global_diffusion

    if global_model is None:
        return "Please select and load a model first.", None

    resources = (global_t5, global_clap, global_vae, global_vocoder, global_diffusion)
    if any(resource is None for resource in resources):
        return "Resources not properly loaded. Please reload the page and try again.", None

    # A non-positive duration would yield zero segments and leave nothing to
    # assemble afterwards, so reject it with the same (message, None) shape
    # used by the other early exits.
    if duration is None or duration <= 0:
        return "Duration must be greater than 0 seconds.", None

    # range() rejects a zero or non-integer step, and a negative step would
    # skip sampling entirely, so clamp to a usable integer.
    batch_size = max(1, int(batch_size))

    # Treat an empty seed field like 0: pick a random seed.
    if not seed:
        seed = random.randint(1, 1000000)
    seed = int(seed)
    print(f"Using seed: {seed}")

    torch.manual_seed(seed)
    # NOTE: this disables autograd process-wide, not just for this call; the
    # VAE decode and vocoder calls below rely on it.
    torch.set_grad_enabled(False)

    # Ensure we're using CPU if CUDA is not available
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA is not available. Falling back to CPU.")
        device = "cpu"

    # Calculate the number of segments needed for the desired duration
    segment_duration = 10  # Each segment is 10 seconds
    num_segments = int(np.ceil(duration / segment_duration))

    # Per-request constants, built once instead of once per segment.
    latent_size = (256, 16)
    conds_txt = [prompt]
    unconds_txt = ["low quality, gentle"]
    L = len(conds_txt)

    all_waveforms = []

    for i in range(num_segments):
        progress(i / num_segments, desc=f"Generating segment {i+1}/{num_segments}")

        # Each segment gets its own seed derived from the base seed, so the
        # segments differ from one another while the whole run stays
        # reproducible for a given seed.
        torch.manual_seed(seed + i)

        # Noise is drawn on the default device and then moved, which keeps
        # the random stream independent of the target device.
        init_noise = torch.randn(L, 8, latent_size[0], latent_size[1]).to(device)

        img, conds = prepare(global_t5, global_clap, init_noise, conds_txt)
        _, unconds = prepare(global_t5, global_clap, init_noise, unconds_txt)

        # Run the sampler over the inputs in chunks of batch_size
        images = []
        for batch_start in range(0, img.shape[0], batch_size):
            batch_end = min(batch_start + batch_size, img.shape[0])
            batch_img = img[batch_start:batch_end]
            batch_conds = {k: v[batch_start:batch_end] for k, v in conds.items()}
            batch_unconds = {k: v[batch_start:batch_end] for k, v in unconds.items()}

            with torch.no_grad():
                batch_images = global_diffusion.sample_with_xps(
                    global_model, batch_img, conds=batch_conds, null_cond=batch_unconds,
                    sample_steps=steps, cfg=cfg_scale
                )
            # Only the last entry of the sampler's return value is kept.
            images.append(batch_images[-1])

        images = torch.cat(images, dim=0)

        # Fold the flattened 2x2 patches back into a (b, c, 256, 16) latent,
        # matching latent_size above (128 * 2 by 8 * 2).
        images = rearrange(
            images,
            "b (h w) (c ph pw) -> b c (h ph) (w pw)",
            h=128,
            w=8,
            ph=2,
            pw=2,)

        latents = 1 / global_vae.config.scaling_factor * images
        mel_spectrogram = global_vae.decode(latents).sample

        # Only the first decoded item is converted to audio.
        x_i = mel_spectrogram[0]
        if x_i.dim() == 4:
            x_i = x_i.squeeze(1)
        waveform = global_vocoder(x_i)
        waveform = waveform[0].cpu().float().detach().numpy()

        all_waveforms.append(waveform)

    # ... (keep the rest of the function unchanged)

# ... (keep the rest of the file unchanged)

# Gradio Interface
with gr.Blocks(theme=theme) as iface:
    # ... (keep the interface definition unchanged)

    def on_load_model_click(model_name, device, url):
        """Load the shared resources, then the requested model.

        Args:
            model_name: Model chosen in the dropdown; used when no URL is given.
            device: ``"cuda"`` or ``"cpu"``. Falls back to ``"cpu"`` when CUDA
                is requested but unavailable.
            url: Optional model URL. When non-blank it takes precedence over
                ``model_name``.

        Returns:
            The status returned by ``load_resources`` if it contains
            ``"Failed"``, otherwise whatever ``load_model`` returns.
        """
        # Ensure we're using CPU if CUDA is not available
        if device == "cuda" and not torch.cuda.is_available():
            print("CUDA is not available. Falling back to CPU.")
            device = "cpu"

        resource_status = load_resources(device)
        if "Failed" in resource_status:
            return resource_status

        # A URL box containing only whitespace is truthy; strip it so that
        # case falls through to the dropdown selection instead of being
        # handed to load_model as a URL.
        if isinstance(url, str):
            url = url.strip()

        if url:
            return load_model(None, device, model_url=url)
        return load_model(model_name, device)

    load_model_button.click(on_load_model_click, inputs=[model_dropdown, device_choice, model_url], outputs=[model_status])
    # batch_size and progress are not wired here; generate_music uses their defaults.
    generate_button.click(generate_music, inputs=[prompt, seed, cfg_scale, steps, duration, device_choice], outputs=[output_status, output_audio])

    # Load default model and resources on startup, and show the outcome in the
    # status box so a failed default load is visible instead of being dropped.
    iface.load(lambda: on_load_model_click(default_model, "cpu", None), inputs=None, outputs=[model_status])

# Launch the interface. This runs whenever the module is executed or
# imported; there is no `if __name__ == "__main__":` guard.
iface.launch()