# 🎙️ Complete MOSS-TTSD Enhanced Space implementation
# Author: vincenthugging — commit 136b79d
import logging
import torchaudio
import os
import sys
import glob
import debugpy
import torch
import numpy as np
import re
def count_params_by_module(model_name, model):
    """Log a per-top-level-module parameter breakdown of *model*.

    For every parameter returned by ``model.named_parameters()`` the count is
    attributed to the first dotted component of its name.  Totals (all /
    requires_grad / frozen) are logged per module and for the whole model,
    in millions of parameters.  Purely a logging helper; returns ``None``.

    Args:
        model_name: Label used in the log output.
        model: Any object exposing ``named_parameters()`` (e.g. ``nn.Module``).
    """
    logging.info(f"Counting num_parameters of {model_name}:")
    param_stats = {}
    total_params = 0  # Count total parameters
    total_requires_grad_params = 0  # Count parameters with requires_grad=True
    total_no_grad_params = 0  # Count parameters with requires_grad=False

    for name, param in model.named_parameters():
        # Attribute the parameter to its top-level submodule.
        module_name = name.split('.')[0]
        if module_name not in param_stats:
            param_stats[module_name] = {'total': 0, 'requires_grad': 0, 'no_grad': 0}
        param_num = param.numel()
        param_stats[module_name]['total'] += param_num
        total_params += param_num
        if param.requires_grad:
            param_stats[module_name]['requires_grad'] += param_num
            total_requires_grad_params += param_num
        else:
            param_stats[module_name]['no_grad'] += param_num
            total_no_grad_params += param_num

    # Guard: a parameterless model would make the max() calls below raise
    # ValueError (max() of an empty sequence).
    if not param_stats:
        logging.info(f"\t{model_name} has no parameters")
        logging.info(f"################################################################")
        return

    # Column widths so the per-module lines align.
    max_module_name_length = max(len(module) for module in param_stats)
    max_param_length = max(len(f"{stats['total'] / 1e6:.2f}M") for stats in param_stats.values())

    # Output parameter statistics for each module
    for module, stats in param_stats.items():
        logging.info(f"\t{module:<{max_module_name_length}}: "
                     f"Total: {stats['total'] / 1e6:<{max_param_length}.2f}M, "
                     f"Requires Grad: {stats['requires_grad'] / 1e6:<{max_param_length}.2f}M, "
                     f"No Grad: {stats['no_grad'] / 1e6:<{max_param_length}.2f}M")

    # Output total parameter statistics
    logging.info(f"\tTotal parameters: {total_params / 1e6:.2f}M parameters")
    logging.info(f"\tRequires Grad parameters: {total_requires_grad_params / 1e6:.2f}M parameters")
    logging.info(f"\tNo Grad parameters: {total_no_grad_params / 1e6:.2f}M parameters")
    logging.info(f"################################################################")
def load_and_resample_audio(audio_path, target_sample_rate):
    """Load an audio file and resample it to *target_sample_rate*.

    Returns the waveform tensor with singleton dimensions squeezed out
    (a (1, T) load becomes a 1-D tensor of length T).
    """
    wav, source_rate = torchaudio.load(audio_path)  # (1, T) tensor
    if source_rate == target_sample_rate:
        return wav.squeeze()
    resampled = torchaudio.functional.resample(wav, source_rate, target_sample_rate)
    return resampled.squeeze()
def set_logging():
    """Configure root logging: INFO level to stdout, each record tagged
    with the process rank taken from the RANK environment variable."""
    rank = os.environ.get("RANK", 0)
    log_format = f"%(asctime)s [RANK {rank}] (%(module)s:%(lineno)d) %(levelname)s : %(message)s"
    logging.basicConfig(level=logging.INFO, stream=sys.stdout, format=log_format)
def waiting_for_debug(ip, port):
    """Open a debugpy listener on (ip, port) and block until a client attaches."""
    rank = os.environ.get("RANK", "0")
    # Replace localhost with cluster node IP
    debugpy.listen((ip, port))
    logging.info("[rank = %s] Waiting for debugger attach...", rank)
    debugpy.wait_for_client()
    logging.info("[rank = %s] Debugger attached", rank)
def load_audio(audio_path, target_sample_rate):
    """Load an audio file as a mono tensor of shape (1, 1, time).

    Multi-channel input is downmixed to mono by averaging channels, and the
    waveform is resampled to *target_sample_rate* when the file's native rate
    differs.

    Args:
        audio_path: Path to the audio file.
        target_sample_rate: Desired output sample rate in Hz.

    Returns:
        A float tensor of shape (1, 1, time).
    """
    # Load audio file, wav shape: (channels, time)
    wav, raw_sample_rate = torchaudio.load(audio_path)
    # If multi-channel, convert to mono by averaging across channels
    if wav.shape[0] > 1:
        wav = torch.mean(wav, dim=0, keepdim=True)  # (1, time)
    # Resample if necessary
    if raw_sample_rate != target_sample_rate:
        wav = torchaudio.functional.resample(wav, raw_sample_rate, target_sample_rate)
    # Reshape directly to (1, 1, time).  The previous implementation took a
    # redundant detour through numpy (squeeze -> expand_dims -> torch.tensor
    # -> reshape) that produced exactly this tensor.
    return wav.reshape(1, 1, -1)
def save_audio(audio_outpath, audio_out, sample_rate):
    """Write *audio_out* to *audio_outpath* as 16-bit signed PCM."""
    pcm16 = dict(encoding='PCM_S', bits_per_sample=16)
    torchaudio.save(audio_outpath, audio_out, sample_rate=sample_rate, **pcm16)
    logging.info(f"Successfully saved audio at {audio_outpath}")
def find_audio_files(input_dir):
    """Recursively collect .flac/.mp3/.wav files under *input_dir*.

    Returns the matching paths sorted lexicographically.
    """
    patterns = ('*.flac', '*.mp3', '*.wav')
    audios_input = [
        path
        for pattern in patterns
        for path in glob.glob(os.path.join(input_dir, '**', pattern), recursive=True)
    ]
    logging.info(f"Found {len(audios_input)} audio files in {input_dir}")
    return sorted(audios_input)
def normalize_text(text):
    """Lowercase *text*, strip punctuation, and collapse whitespace runs.

    Keeps word characters, whitespace, and CJK ideographs (U+4E00-U+9FFF);
    everything else (English and Chinese punctuation alike) is removed.
    """
    stripped = re.sub(r'[^\w\s\u4e00-\u9fff]', '', text)
    # lower() only affects cased scripts such as Latin; CJK is untouched.
    # split()/join collapses any run of whitespace to a single space.
    return ' '.join(stripped.lower().split())