import numpy as np
import torch
from typing import Optional, List, Tuple, NamedTuple, Union
from models import PipelineWrapper
import torchaudio
from audioldm.utils import get_duration

# Optional global cap (in seconds) on the duration of loaded audio; None disables clipping.
MAX_DURATION = None


# Bundles the outputs of PipelineWrapper.encode_text for a single (possibly negative) prompt.
class PromptEmbeddings(NamedTuple):
    embedding_hidden_states: torch.Tensor
    embedding_class_lables: torch.Tensor
    boolean_prompt_mask: torch.Tensor
def load_audio(audio_path: Union[str, np.ndarray], fn_STFT, left: int = 0, right: int = 0,
               device: Optional[torch.device] = None,
               return_wav: bool = False, stft: bool = False, model_sr: Optional[int] = None) -> torch.Tensor:
    if stft:  # AudioLDM/TANGO loading to a mel spectrogram
        if type(audio_path) is str:
            import audioldm
            import audioldm.audio

            duration = get_duration(audio_path)
            if MAX_DURATION is not None:
                duration = min(duration, MAX_DURATION)
            # AudioLDM extracts roughly 102.4 mel frames per second of audio.
            mel, _, wav = audioldm.audio.wav_to_fbank(
                audio_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT)
            mel = mel.unsqueeze(0)
        else:
            # A precomputed mel spectrogram was passed in directly, so there is no source
            # waveform; recover the duration from the frame count, assuming the same
            # layout and ~102.4 frames/s as wav_to_fbank above.
            mel = audio_path
            wav = None
            duration = mel.shape[1] / 102.4

        # Optionally crop frames from the edges of the spectrogram.
        c, h, w = mel.shape
        left = min(left, w - 1)
        right = min(right, w - left - 1)
        mel = mel[:, :, left:w - right]

        mel = mel.unsqueeze(0).to(device)
        if return_wav:
            return mel, 16000, duration, wav
        return mel, model_sr, duration
    else:
        waveform, sr = torchaudio.load(audio_path)
        if sr != model_sr:
            waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=model_sr)

        def normalize_wav(waveform):
            # Remove the DC offset and scale the peak amplitude to 0.5.
            waveform = waveform - torch.mean(waveform)
            waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
            return waveform * 0.5

        waveform = normalize_wav(waveform)
        waveform = waveform.float()
        if MAX_DURATION is not None:
            # Cut the waveform to at most MAX_DURATION seconds.
            duration = min(waveform.shape[-1] / model_sr, MAX_DURATION)
            waveform = waveform[:, :int(duration * model_sr)]
        duration = waveform.shape[-1] / model_sr
        return waveform, model_sr, duration
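
# Example usage (a minimal sketch, not from the original file): it assumes `fn_STFT` is an
# audioldm TacotronSTFT instance and that the model operates at 16 kHz, as AudioLDM does.
#
#     mel, sr, duration = load_audio("input.wav", fn_STFT, stft=True,
#                                    device=torch.device("cuda"), model_sr=16000)
#     waveform, sr, duration = load_audio("input.wav", None, stft=False, model_sr=16000)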
def get_height_of_spectrogram(length: int, ldm_stable: PipelineWrapper) -> int:
    vocoder_upsample_factor = np.prod(ldm_stable.model.vocoder.config.upsample_rates) / \
        ldm_stable.model.vocoder.config.sampling_rate

    if length is None:
        length = ldm_stable.model.unet.config.sample_size * ldm_stable.model.vae_scale_factor * \
            vocoder_upsample_factor

    height = int(length / vocoder_upsample_factor)
    # original_waveform_length = int(length * ldm_stable.model.vocoder.config.sampling_rate)

    if height % ldm_stable.model.vae_scale_factor != 0:
        height = int(np.ceil(height / ldm_stable.model.vae_scale_factor)) * ldm_stable.model.vae_scale_factor
        print(
            f"Audio length in seconds {length} is increased to {height * vocoder_upsample_factor} "
            f"so that it can be handled by the model. It will be cut to {length} after the "
            f"denoising process."
        )

    return height
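
# Example usage (a sketch; assumes `ldm_stable` is an already-constructed PipelineWrapper and
# that `length` is the desired audio length in seconds, as the message above suggests):
#
#     height = get_height_of_spectrogram(length=10, ldm_stable=ldm_stable)
#     # `height` is the number of spectrogram frames, rounded up to a multiple of the
#     # VAE scale factor so the latents tile evenly.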
def get_text_embeddings(target_prompt: List[str], target_neg_prompt: List[str], ldm_stable: PipelineWrapper
                        ) -> Tuple[torch.Tensor, PromptEmbeddings, PromptEmbeddings]:
    text_embeddings_hidden_states, text_embeddings_class_labels, text_embeddings_boolean_prompt_mask = \
        ldm_stable.encode_text(target_prompt)
    uncond_embedding_hidden_states, uncond_embedding_class_lables, uncond_boolean_prompt_mask = \
        ldm_stable.encode_text(target_neg_prompt)

    text_emb = PromptEmbeddings(embedding_hidden_states=text_embeddings_hidden_states,
                                boolean_prompt_mask=text_embeddings_boolean_prompt_mask,
                                embedding_class_lables=text_embeddings_class_labels)
    uncond_emb = PromptEmbeddings(embedding_hidden_states=uncond_embedding_hidden_states,
                                  boolean_prompt_mask=uncond_boolean_prompt_mask,
                                  embedding_class_lables=uncond_embedding_class_lables)

    return text_embeddings_class_labels, text_emb, uncond_emb
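
# Example usage (a sketch; the prompts are illustrative):
#
#     class_labels, text_emb, uncond_emb = get_text_embeddings(
#         ["a dog barking in the rain"], [""], ldm_stable)
#     # `text_emb` and `uncond_emb` bundle the hidden states, class-label embeddings and
#     # attention masks needed for classifier-free guidance.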