atlonxp
		
	committed on
		
		
					Commit 
							
							·
						
						77ba66f
	
1
								Parent(s):
							
							6601302
								
Update dataset.py
Browse files
Change recursive approach to while loop, avoiding potential memory leak.
- src/f5_tts/model/dataset.py +25 -19
    	
        src/f5_tts/model/dataset.py
    CHANGED
    
    | @@ -127,38 +127,44 @@ class CustomDataset(Dataset): | |
| 127 | 
             
                    return len(self.data)
         | 
| 128 |  | 
| 129 | 
             
                def __getitem__(self, index):
         | 
| 130 | 
            -
                     | 
| 131 | 
            -
             | 
| 132 | 
            -
             | 
| 133 | 
            -
             | 
| 134 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 135 | 
             
                    if self.preprocessed_mel:
         | 
| 136 | 
             
                        mel_spec = torch.tensor(row["mel_spec"])
         | 
| 137 | 
            -
             | 
| 138 | 
             
                    else:
         | 
| 139 | 
             
                        audio, source_sample_rate = torchaudio.load(audio_path)
         | 
|  | |
|  | |
| 140 | 
             
                        if audio.shape[0] > 1:
         | 
| 141 | 
             
                            audio = torch.mean(audio, dim=0, keepdim=True)
         | 
| 142 | 
            -
             | 
| 143 | 
            -
                         | 
| 144 | 
            -
                            return self.__getitem__((index + 1) % len(self.data))
         | 
| 145 | 
            -
             | 
| 146 | 
             
                        if source_sample_rate != self.target_sample_rate:
         | 
| 147 | 
             
                            resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
         | 
| 148 | 
             
                            audio = resampler(audio)
         | 
| 149 | 
            -
             | 
|  | |
| 150 | 
             
                        mel_spec = self.mel_spectrogram(audio)
         | 
| 151 | 
            -
                        mel_spec = mel_spec.squeeze(0)  #  | 
| 152 | 
            -
             | 
| 153 | 
            -
                    return  | 
| 154 | 
            -
                        mel_spec | 
| 155 | 
            -
                        text | 
| 156 | 
            -
                     | 
| 157 |  | 
| 158 |  | 
| 159 | 
             
            # Dynamic Batch Sampler
         | 
| 160 | 
            -
             | 
| 161 | 
            -
             | 
| 162 | 
             
            class DynamicBatchSampler(Sampler[list[int]]):
         | 
| 163 | 
             
                """Extension of Sampler that will do the following:
         | 
| 164 | 
             
                1.  Change the batch size (essentially number of sequences)
         | 
|  | |
| 127 | 
             
                    return len(self.data)
         | 
| 128 |  | 
| 129 | 
             
                def __getitem__(self, index):
         | 
| 130 | 
            +
                    while True:
         | 
| 131 | 
            +
                        row = self.data[index]
         | 
| 132 | 
            +
                        audio_path = row["audio_path"]
         | 
| 133 | 
            +
                        text = row["text"]
         | 
| 134 | 
            +
                        duration = row["duration"]
         | 
| 135 | 
            +
                
         | 
| 136 | 
            +
                        # Check if the duration is within the acceptable range
         | 
| 137 | 
            +
                        if 0.3 <= duration <= 30:
         | 
| 138 | 
            +
                            break  # Valid sample found, exit the loop
         | 
| 139 | 
            +
                
         | 
| 140 | 
            +
                        # Move to the next index and wrap around if necessary
         | 
| 141 | 
            +
                        index = (index + 1) % len(self.data)
         | 
| 142 | 
            +
                
         | 
| 143 | 
             
                    if self.preprocessed_mel:
         | 
| 144 | 
             
                        mel_spec = torch.tensor(row["mel_spec"])
         | 
|  | |
| 145 | 
             
                    else:
         | 
| 146 | 
             
                        audio, source_sample_rate = torchaudio.load(audio_path)
         | 
| 147 | 
            +
                
         | 
| 148 | 
            +
                        # If the audio has multiple channels, convert it to mono
         | 
| 149 | 
             
                        if audio.shape[0] > 1:
         | 
| 150 | 
             
                            audio = torch.mean(audio, dim=0, keepdim=True)
         | 
| 151 | 
            +
                
         | 
| 152 | 
            +
                        # Resample the audio if necessary
         | 
|  | |
|  | |
| 153 | 
             
                        if source_sample_rate != self.target_sample_rate:
         | 
| 154 | 
             
                            resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
         | 
| 155 | 
             
                            audio = resampler(audio)
         | 
| 156 | 
            +
                
         | 
| 157 | 
            +
                        # Compute the mel spectrogram
         | 
| 158 | 
             
                        mel_spec = self.mel_spectrogram(audio)
         | 
| 159 | 
            +
                        mel_spec = mel_spec.squeeze(0)  # Convert from (1, D, T) to (D, T)
         | 
| 160 | 
            +
                
         | 
| 161 | 
            +
                    return {
         | 
| 162 | 
            +
                        "mel_spec": mel_spec,
         | 
| 163 | 
            +
                        "text": text,
         | 
| 164 | 
            +
                    }
         | 
| 165 |  | 
| 166 |  | 
| 167 | 
             
            # Dynamic Batch Sampler
         | 
|  | |
|  | |
| 168 | 
             
            class DynamicBatchSampler(Sampler[list[int]]):
         | 
| 169 | 
             
                """Extension of Sampler that will do the following:
         | 
| 170 | 
             
                1.  Change the batch size (essentially number of sequences)
         |