import spaces
import argparse, os, sys, glob
import pathlib

directory = pathlib.Path(os.getcwd())
print(directory)
sys.path.append(str(directory))

import torch
import numpy as np
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
import pandas as pd
from tqdm import tqdm
import preprocess.n2s_by_openai as n2s
from vocoder.bigvgan.models import VocoderBigVGAN
import soundfile
import torchaudio, math
import gradio
import gradio as gr
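
# Gradio demo for Make-An-Audio 3 (text-to-music generation): a text prompt is
# encoded into conditioning, a latent mel spectrogram is sampled with the model
# defined in the txt2music config, decoded by the first-stage model, and finally
# rendered into a waveform by the BigVGAN vocoder.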


def load_model_from_config(config, ckpt=None, verbose=True):
    """Instantiate the model from the config and optionally load a checkpoint."""
    model = instantiate_from_config(config.model)
    if ckpt:
        print(f"Loading model from {ckpt}")
        pl_sd = torch.load(ckpt, map_location="cpu")
        sd = pl_sd["state_dict"]
        m, u = model.load_state_dict(sd, strict=False)
        if len(m) > 0 and verbose:
            print("missing keys:")
            print(m)
        if len(u) > 0 and verbose:
            print("unexpected keys:")
            print(u)
    else:
        print("Note that no ckpt is loaded !!!")
    model.cuda()
    model.eval()
    return model
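

# GenSamples wraps the sampling loop: it builds (optionally classifier-free-guided)
# conditioning from the prompt, draws latent mel spectrograms from the model, and
# saves them as .npy mel files and/or .wav files rendered with the vocoder.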
class GenSamples:
    def __init__(self, opt, model, outpath, config, vocoder=None, save_mel=True, save_wav=True) -> None:
        self.opt = opt
        self.model = model
        self.outpath = outpath
        if save_wav:
            assert vocoder is not None
        self.vocoder = vocoder
        self.save_mel = save_mel
        self.save_wav = save_wav
        self.channel_dim = self.model.channels
        self.config = config
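
    # gen_test_sample: encode the prompt, sample a latent of shape
    # [channel_dim, H, W], decode it to a mel spectrogram, and (optionally)
    # vocode it to a waveform; returns one record dict per generated sample.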
    def gen_test_sample(self, prompt, mel_name=None, wav_name=None, gt=None, video=None):
        # prompt is {'ori_caption': 'xxx', 'struct_caption': 'xxx'}
        uc = None
        record_dicts = []
        if self.opt['scale'] != 1.0:
            try:  # audiocaps
                uc = self.model.get_learned_conditioning({'ori_caption': "", 'struct_caption': ""})
            except:  # audioset
                uc = self.model.get_learned_conditioning(prompt['ori_caption'])
        for n in range(self.opt['n_iter']):
            try:  # audiocaps
                c = self.model.get_learned_conditioning(prompt)  # shape [1, 77, 1280]: still per-token embeddings, not yet a single sentence embedding
            except:  # audioset
                c = self.model.get_learned_conditioning(prompt['ori_caption'])
            if self.channel_dim > 0:
                shape = [self.channel_dim, self.opt['H'], self.opt['W']]  # (z_dim, 80//2^x, 848//2^x)
            else:
                shape = [1, self.opt['H'], self.opt['W']]
            x0 = torch.randn(shape, device=self.model.device)
            if self.opt['scale'] == 1:  # w/o cfg
                sample, _ = self.model.sample(c, 1, timesteps=self.opt['ddim_steps'], x_latent=x0)
            else:  # cfg
                sample, _ = self.model.sample_cfg(c, self.opt['scale'], uc, 1, timesteps=self.opt['ddim_steps'], x_latent=x0)
            x_samples_ddim = self.model.decode_first_stage(sample)
            for idx, spec in enumerate(x_samples_ddim):
                spec = spec.squeeze(0).cpu().numpy()
                print(spec[0])
                record_dict = {'caption': prompt['ori_caption'][0]}
                if self.save_mel:
                    mel_path = os.path.join(self.outpath, mel_name + f'_{idx}.npy')
                    np.save(mel_path, spec)
                    record_dict['mel_path'] = mel_path
                if self.save_wav:
                    wav = self.vocoder.vocode(spec)
                    wav_path = os.path.join(self.outpath, wav_name + f'_{idx}.wav')
                    soundfile.write(wav_path, wav, self.opt['sample_rate'])
                    record_dict['audio_path'] = wav_path
                record_dicts.append(record_dict)
        return record_dicts
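

# Note: the top-level `import spaces` suggests this demo targets Hugging Face
# ZeroGPU hardware, where the GPU-bound entry point is normally decorated with
# `@spaces.GPU`. Whether that was intended here is an assumption, so the
# decorator is left commented out:
# @spaces.GPU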
def infer(ori_prompt, ddim_steps, scale, seed):
    # Seed the RNGs so the `seed` slider in the UI actually changes the result.
    seed = int(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    prompt = dict(ori_caption=ori_prompt, struct_caption=f'<{ori_prompt}& all>')
    opt = {
        'sample_rate': 16000,
        'outdir': 'outputs/txt2music-samples',
        'ddim_steps': ddim_steps,
        'n_iter': 1,
        'H': 20,
        'W': 312,
        'scale': scale,
        'resume': 'useful_ckpts/music_generation/119.ckpt',
        'base': 'configs/txt2music-cfm1-cfg-LargeDiT3.yaml',
        'vocoder_ckpt': 'useful_ckpts/bigvnat',
    }
    config = OmegaConf.load(opt['base'])
    model = load_model_from_config(config, opt['resume'])
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    os.makedirs(opt['outdir'], exist_ok=True)
    vocoder = VocoderBigVGAN(opt['vocoder_ckpt'], device)
    generator = GenSamples(opt, model, opt['outdir'], config, vocoder, save_mel=False, save_wav=True)
    with torch.no_grad():
        with model.ema_scope():
            wav_name = f'{prompt["ori_caption"].strip().replace(" ", "-")}'
            generator.gen_test_sample(prompt, wav_name=wav_name)
    file_path = os.path.join(opt['outdir'], wav_name + '_0.wav')
    print(f"Your samples are ready and waiting for you here: \n{file_path} \nEnjoy.")
    return file_path


def my_inference_function(text_prompt, ddim_steps, scale, seed):
    file_path = infer(text_prompt, ddim_steps, scale, seed)
    return file_path
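

# Gradio Blocks UI: a prompt textbox and run button on the left (with ddim_steps,
# guidance scale, and seed under "Advanced options"), the generated audio on the
# right, and a set of example prompts below.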
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("## Make-An-Audio 3: Transforming Text into Audio via Flow-based Large Diffusion Transformers")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt: Input your text here.")
            run_button = gr.Button()
            with gr.Accordion("Advanced options", open=False):
                ddim_steps = gr.Slider(label="ddim_steps", minimum=1, maximum=50, value=25, step=1)
                scale = gr.Slider(
                    label="Guidance scale (larger => more relevant to the text, but quality may drop)",
                    minimum=0.1, maximum=8.0, value=3.0, step=0.1,
                )
                seed = gr.Slider(
                    label="Seed: changing this value (any integer) leads to a different generation result.",
                    minimum=0,
                    maximum=2147483647,
                    step=1,
                    value=44,
                )
        with gr.Column():
            outaudio = gr.Audio()
    run_button.click(fn=my_inference_function,
                     inputs=[prompt, ddim_steps, scale, seed], outputs=[outaudio])
    with gr.Row():
        with gr.Column():
            gr.Examples(
                examples=[
                    ['An amateur recording features a steel drum playing in a higher register', 25, 5, 55],
                    ['An instrumental song with a caribbean feel, happy mood, and featuring steel pan music, programmed percussion, and bass', 25, 5, 55],
                    ['This musical piece features a playful and emotionally melodic male vocal accompanied by piano', 25, 5, 55],
                    ['An eerie yet calming experimental electronic track featuring haunting synthesizer strings and pads', 25, 5, 55],
                    ['A slow tempo pop instrumental piece featuring only acoustic guitar with fingerstyle and percussive strumming techniques', 25, 5, 55],
                ],
                inputs=[prompt, ddim_steps, scale, seed],
                outputs=[outaudio],
            )
        with gr.Column():
            pass

demo.launch()

# gradio_interface = gradio.Interface(
#     fn=my_inference_function,
#     inputs="text",
#     outputs="audio",
# )
# gradio_interface.launch()
# text_prompt = 'An amateur recording features a steel drum playing in a higher register'
# # text_prompt = 'A slow tempo pop instrumental piece featuring only acoustic guitar with fingerstyle and percussive strumming techniques'
# ddim_steps = 25
# scale = 5.0
# seed = 55
# my_inference_function(text_prompt, ddim_steps, scale, seed)