Yushen CHEN commited on
Commit
5f7944a
·
unverified ·
1 Parent(s): 07b100e

Update dataset.py, formatting

Browse files
Files changed (1) hide show
  1. src/f5_tts/model/dataset.py +12 -13
src/f5_tts/model/dataset.py CHANGED
@@ -133,31 +133,30 @@ class CustomDataset(Dataset):
133
  text = row["text"]
134
  duration = row["duration"]
135
 
136
- # Check if the duration is within the acceptable range
137
  if 0.3 <= duration <= 30:
138
- break # Valid sample found, exit the loop
139
-
140
- # Move to the next index and wrap around if necessary
141
  index = (index + 1) % len(self.data)
142
-
143
  if self.preprocessed_mel:
144
  mel_spec = torch.tensor(row["mel_spec"])
145
  else:
146
  audio, source_sample_rate = torchaudio.load(audio_path)
147
-
148
- # If the audio has multiple channels, convert it to mono
149
  if audio.shape[0] > 1:
150
  audio = torch.mean(audio, dim=0, keepdim=True)
151
-
152
- # Resample the audio if necessary
153
  if source_sample_rate != self.target_sample_rate:
154
  resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
155
  audio = resampler(audio)
156
-
157
- # Compute the mel spectrogram
158
  mel_spec = self.mel_spectrogram(audio)
159
- mel_spec = mel_spec.squeeze(0) # Convert from (1, D, T) to (D, T)
160
-
161
  return {
162
  "mel_spec": mel_spec,
163
  "text": text,
 
133
  text = row["text"]
134
  duration = row["duration"]
135
 
136
+ # filter by given length
137
  if 0.3 <= duration <= 30:
138
+ break # valid
139
+
 
140
  index = (index + 1) % len(self.data)
141
+
142
  if self.preprocessed_mel:
143
  mel_spec = torch.tensor(row["mel_spec"])
144
  else:
145
  audio, source_sample_rate = torchaudio.load(audio_path)
146
+
147
+ # make sure mono input
148
  if audio.shape[0] > 1:
149
  audio = torch.mean(audio, dim=0, keepdim=True)
150
+
151
+ # resample if necessary
152
  if source_sample_rate != self.target_sample_rate:
153
  resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
154
  audio = resampler(audio)
155
+
156
+ # to mel spectrogram
157
  mel_spec = self.mel_spectrogram(audio)
158
+ mel_spec = mel_spec.squeeze(0) # '1 d t -> d t'
159
+
160
  return {
161
  "mel_spec": mel_spec,
162
  "text": text,