import logging
import torchaudio
import os
import sys
import glob
import debugpy
import torch
import numpy as np
import re
def count_params_by_module(model_name, model):
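    """Log a per-module breakdown of the parameter counts of `model` (grouped by the
    top-level module name), split into trainable (requires_grad) and frozen parameters,
    reported in millions."""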
    logging.info(f"Counting num_parameters of {model_name}:")
    param_stats = {}
    total_params = 0  # Total parameters
    total_requires_grad_params = 0  # Parameters with requires_grad=True
    total_no_grad_params = 0  # Parameters with requires_grad=False
    for name, param in model.named_parameters():
        module_name = name.split('.')[0]
        if module_name not in param_stats:
            param_stats[module_name] = {'total': 0, 'requires_grad': 0, 'no_grad': 0}
        param_num = param.numel()
        param_stats[module_name]['total'] += param_num
        total_params += param_num
        if param.requires_grad:
            param_stats[module_name]['requires_grad'] += param_num
            total_requires_grad_params += param_num
        else:
            param_stats[module_name]['no_grad'] += param_num
            total_no_grad_params += param_num
    # Compute column widths for aligned output
    max_module_name_length = max(len(module) for module in param_stats)
    max_param_length = max(len(f"{stats['total'] / 1e6:.2f}M") for stats in param_stats.values())
    # Log parameter statistics for each module
    for module, stats in param_stats.items():
        logging.info(f"\t{module:<{max_module_name_length}}: "
                     f"Total: {stats['total'] / 1e6:<{max_param_length}.2f}M, "
                     f"Requires Grad: {stats['requires_grad'] / 1e6:<{max_param_length}.2f}M, "
                     f"No Grad: {stats['no_grad'] / 1e6:<{max_param_length}.2f}M")
    # Log overall totals
    logging.info(f"\tTotal parameters: {total_params / 1e6:.2f}M")
    logging.info(f"\tRequires Grad parameters: {total_requires_grad_params / 1e6:.2f}M")
    logging.info(f"\tNo Grad parameters: {total_no_grad_params / 1e6:.2f}M")
    logging.info("################################################################")
def load_and_resample_audio(audio_path, target_sample_rate):
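    """Load an audio file and resample it to `target_sample_rate` if needed.
    Returns the waveform with singleton dimensions squeezed out."""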
    wav, raw_sample_rate = torchaudio.load(audio_path)  # (channels, T) tensor
    if raw_sample_rate != target_sample_rate:
        wav = torchaudio.functional.resample(wav, raw_sample_rate, target_sample_rate)
    return wav.squeeze()
def set_logging():
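    """Configure root logging to stdout, tagging each record with the distributed
    rank taken from the RANK environment variable."""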
    rank = os.environ.get("RANK", "0")
    logging.basicConfig(
        level=logging.INFO,
        stream=sys.stdout,
        format=f"%(asctime)s [RANK {rank}] (%(module)s:%(lineno)d) %(levelname)s : %(message)s",
    )
def waiting_for_debug(ip, port):
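    """Block until a debugpy client attaches to (ip, port); useful for attaching a
    debugger to a specific rank of a distributed job."""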
    rank = os.environ.get("RANK", "0")
    debugpy.listen((ip, port))  # Replace localhost with the cluster node IP when debugging remotely
    logging.info(f"[rank = {rank}] Waiting for debugger attach...")
    debugpy.wait_for_client()
    logging.info(f"[rank = {rank}] Debugger attached")
def load_audio(audio_path, target_sample_rate):
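    """Load an audio file as a mono waveform at `target_sample_rate`, returned as a
    (1, 1, time) tensor."""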
    # Load audio file, wav shape: (channels, time)
    wav, raw_sample_rate = torchaudio.load(audio_path)
    # If multi-channel, downmix to mono by averaging across channels
    if wav.shape[0] > 1:
        wav = torch.mean(wav, dim=0, keepdim=True)  # (1, time)
    # Resample if necessary
    if raw_sample_rate != target_sample_rate:
        wav = torchaudio.functional.resample(wav, raw_sample_rate, target_sample_rate)
    # Reshape to (batch, channels, time) = (1, 1, time)
    wav = wav.reshape(1, 1, -1)
    return wav
def save_audio(audio_outpath, audio_out, sample_rate):
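    """Write `audio_out` (a (channels, time) tensor) to `audio_outpath` as 16-bit PCM."""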
    torchaudio.save(
        audio_outpath,
        audio_out,
        sample_rate=sample_rate,
        encoding='PCM_S',
        bits_per_sample=16,
    )
    logging.info(f"Successfully saved audio at {audio_outpath}")
def find_audio_files(input_dir):
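    """Recursively collect .flac, .mp3, and .wav files under `input_dir`, returning
    the paths sorted alphabetically."""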
    audio_extensions = ['*.flac', '*.mp3', '*.wav']
    audios_input = []
    for ext in audio_extensions:
        audios_input.extend(glob.glob(os.path.join(input_dir, '**', ext), recursive=True))
    logging.info(f"Found {len(audios_input)} audio files in {input_dir}")
    return sorted(audios_input)
def normalize_text(text):
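    """Normalize text for comparison: strip punctuation, lowercase, and collapse
    whitespace. Works for both English and Chinese text."""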
    # Remove all punctuation (both English and Chinese punctuation)
    text = re.sub(r'[^\w\s\u4e00-\u9fff]', '', text)
    # Convert to lowercase (affects English only, no effect on Chinese)
    text = text.lower()
    # Collapse consecutive whitespace into single spaces
    text = ' '.join(text.split())
    return text
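

# Minimal usage sketch for the helpers above. The "examples" directory and the
# 16 kHz target rate are placeholder assumptions; adjust them to the actual app configuration.
if __name__ == "__main__":
    set_logging()
    for path in find_audio_files("examples"):
        wav = load_audio(path, target_sample_rate=16000)
        logging.info(f"{path}: loaded waveform with shape {tuple(wav.shape)}")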