import logging import torchaudio import os import sys import glob import debugpy import torch import numpy as np import re def count_params_by_module(model_name, model): logging.info(f"Counting num_parameters of {model_name}:") param_stats = {} total_params = 0 # Count total parameters total_requires_grad_params = 0 # Count parameters with requires_grad=True total_no_grad_params = 0 # Count parameters with requires_grad=False for name, param in model.named_parameters(): module_name = name.split('.')[0] if module_name not in param_stats: param_stats[module_name] = {'total': 0, 'requires_grad': 0, 'no_grad': 0} param_num = param.numel() param_stats[module_name]['total'] += param_num total_params += param_num if param.requires_grad: param_stats[module_name]['requires_grad'] += param_num total_requires_grad_params += param_num else: param_stats[module_name]['no_grad'] += param_num total_no_grad_params += param_num # Calculate maximum width for each column max_module_name_length = max(len(module) for module in param_stats) max_param_length = max(len(f"{stats['total'] / 1e6:.2f}M") for stats in param_stats.values()) # Output parameter statistics for each module for module, stats in param_stats.items(): logging.info(f"\t{module:<{max_module_name_length}}: " f"Total: {stats['total'] / 1e6:<{max_param_length}.2f}M, " f"Requires Grad: {stats['requires_grad'] / 1e6:<{max_param_length}.2f}M, " f"No Grad: {stats['no_grad'] / 1e6:<{max_param_length}.2f}M") # Output total parameter statistics logging.info(f"\tTotal parameters: {total_params / 1e6:.2f}M parameters") logging.info(f"\tRequires Grad parameters: {total_requires_grad_params / 1e6:.2f}M parameters") logging.info(f"\tNo Grad parameters: {total_no_grad_params / 1e6:.2f}M parameters") logging.info(f"################################################################") def load_and_resample_audio(audio_path, target_sample_rate): wav, raw_sample_rate = torchaudio.load(audio_path) # (1, T) tensor if raw_sample_rate != target_sample_rate: wav = torchaudio.functional.resample(wav, raw_sample_rate, target_sample_rate) # tensor return wav.squeeze() def set_logging(): rank = os.environ.get("RANK", 0) logging.basicConfig( level=logging.INFO, stream=sys.stdout, format=f"%(asctime)s [RANK {rank}] (%(module)s:%(lineno)d) %(levelname)s : %(message)s", ) def waiting_for_debug(ip, port): rank = os.environ.get("RANK", "0") debugpy.listen((ip, port)) # Replace localhost with cluster node IP logging.info(f"[rank = {rank}] Waiting for debugger attach...") debugpy.wait_for_client() logging.info(f"[rank = {rank}] Debugger attached") def load_audio(audio_path, target_sample_rate): # Load audio file, wav shape: (channels, time) wav, raw_sample_rate = torchaudio.load(audio_path) # If multi-channel, convert to mono by averaging across channels if wav.shape[0] > 1: wav = torch.mean(wav, dim=0, keepdim=True) # Average across channels, keep channel dim # Resample if necessary if raw_sample_rate != target_sample_rate: wav = torchaudio.functional.resample(wav, raw_sample_rate, target_sample_rate) # Convert to numpy, add channel dimension, then back to tensor with desired shape wav = np.expand_dims(wav.squeeze(0).numpy(), axis=1) # Shape: (time, 1) wav = torch.tensor(wav).reshape(1, 1, -1) # Shape: (1, 1, time) return wav def save_audio(audio_outpath, audio_out, sample_rate): torchaudio.save( audio_outpath, audio_out, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16 ) logging.info(f"Successfully saved audio at {audio_outpath}") def find_audio_files(input_dir): audio_extensions = ['*.flac', '*.mp3', '*.wav'] audios_input = [] for ext in audio_extensions: audios_input.extend(glob.glob(os.path.join(input_dir, '**', ext), recursive=True)) logging.info(f"Found {len(audios_input)} audio files in {input_dir}") return sorted(audios_input) def normalize_text(text): # Remove all punctuation (including English and Chinese punctuation) text = re.sub(r'[^\w\s\u4e00-\u9fff]', '', text) # Convert to lowercase (effective for English, no effect on Chinese) text = text.lower() # Remove extra spaces text = ' '.join(text.split()) return text