atlonxp committed on
Commit
14e923a
·
unverified ·
1 Parent(s): 5cc0253

Update dataset.py

Browse files

Change the recursive approach in `__getitem__` to a while loop, avoiding a potential memory leak from deep recursion when out-of-range samples are skipped.

Files changed (1) hide show
  1. src/f5_tts/model/dataset.py +25 -19
src/f5_tts/model/dataset.py CHANGED
@@ -127,38 +127,44 @@ class CustomDataset(Dataset):
127
  return len(self.data)
128
 
129
  def __getitem__(self, index):
130
- row = self.data[index]
131
- audio_path = row["audio_path"]
132
- text = row["text"]
133
- duration = row["duration"]
134
-
 
 
 
 
 
 
 
 
135
  if self.preprocessed_mel:
136
  mel_spec = torch.tensor(row["mel_spec"])
137
-
138
  else:
139
  audio, source_sample_rate = torchaudio.load(audio_path)
 
 
140
  if audio.shape[0] > 1:
141
  audio = torch.mean(audio, dim=0, keepdim=True)
142
-
143
- if duration > 30 or duration < 0.3:
144
- return self.__getitem__((index + 1) % len(self.data))
145
-
146
  if source_sample_rate != self.target_sample_rate:
147
  resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
148
  audio = resampler(audio)
149
-
 
150
  mel_spec = self.mel_spectrogram(audio)
151
- mel_spec = mel_spec.squeeze(0) # '1 d t -> d t')
152
-
153
- return dict(
154
- mel_spec=mel_spec,
155
- text=text,
156
- )
157
 
158
 
159
  # Dynamic Batch Sampler
160
-
161
-
162
  class DynamicBatchSampler(Sampler[list[int]]):
163
  """Extension of Sampler that will do the following:
164
  1. Change the batch size (essentially number of sequences)
 
127
  return len(self.data)
128
 
129
  def __getitem__(self, index):
130
+ while True:
131
+ row = self.data[index]
132
+ audio_path = row["audio_path"]
133
+ text = row["text"]
134
+ duration = row["duration"]
135
+
136
+ # Check if the duration is within the acceptable range
137
+ if 0.3 <= duration <= 30:
138
+ break # Valid sample found, exit the loop
139
+
140
+ # Move to the next index and wrap around if necessary
141
+ index = (index + 1) % len(self.data)
142
+
143
  if self.preprocessed_mel:
144
  mel_spec = torch.tensor(row["mel_spec"])
 
145
  else:
146
  audio, source_sample_rate = torchaudio.load(audio_path)
147
+
148
+ # If the audio has multiple channels, convert it to mono
149
  if audio.shape[0] > 1:
150
  audio = torch.mean(audio, dim=0, keepdim=True)
151
+
152
+ # Resample the audio if necessary
 
 
153
  if source_sample_rate != self.target_sample_rate:
154
  resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
155
  audio = resampler(audio)
156
+
157
+ # Compute the mel spectrogram
158
  mel_spec = self.mel_spectrogram(audio)
159
+ mel_spec = mel_spec.squeeze(0) # Convert from (1, D, T) to (D, T)
160
+
161
+ return {
162
+ "mel_spec": mel_spec,
163
+ "text": text,
164
+ }
165
 
166
 
167
  # Dynamic Batch Sampler
 
 
168
  class DynamicBatchSampler(Sampler[list[int]]):
169
  """Extension of Sampler that will do the following:
170
  1. Change the batch size (essentially number of sequences)