Sin2pi commited on
Commit
269a1c3
·
verified ·
1 Parent(s): 71514c3

Update echopipeline.py

Browse files
Files changed (1) hide show
  1. echopipeline.py +662 -662
echopipeline.py CHANGED
@@ -1,662 +1,662 @@
1
- import pyworld as pw
2
- import os
3
- import math
4
- import logging
5
- import torch
6
- import torchaudio
7
- import torch.nn.functional as F
8
- import numpy as np
9
- from typing import Optional, Dict, Union, List, Tuple, Any
10
- from functools import partial
11
- from datetime import datetime
12
- from datasets import load_dataset, Audio, concatenate_datasets
13
- from transformers.trainer_seq2seq import Seq2SeqTrainer
14
- from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
15
- import evaluate
16
- from dataclasses import dataclass
17
-
18
- extractor = None
19
- tokenizer = None
20
- optimizer = None
21
- scheduler = None
22
- model = None
23
- Residual = None
24
- MultiheadA = None
25
- Echo = None
26
-
27
- metric = evaluate.load(path="wer")
28
-
29
- @dataclass
30
- class Dimensions:
31
- vocab: int
32
- text_ctx: int
33
- text_dims: int
34
- text_head: int
35
- text_idx: int
36
- mels: int
37
- aud_ctx: int
38
- aud_dims: int
39
- aud_head: int
40
- aud_idx: int
41
- act: str
42
- debug: List[str]
43
- cross_attn: bool
44
- features: List[str]
45
- f0_rotary: bool
46
-
47
- def align_f0(f0, ctx):
48
- ctx = torch.tensor(ctx)
49
- bat, length = f0.shape
50
- if length == ctx:
51
- return f0
52
- frames = length / ctx
53
- idx = torch.arange(ctx, device=f0.device)
54
- idx = (idx * frames).long()
55
- batch_idx = torch.arange(bat, device=f0.device).unsqueeze(1)
56
- return f0[batch_idx, idx.unsqueeze(0).expand(bat, -1)]
57
-
58
- @dataclass
59
- class DataCollator:
60
- tokenizer: Any
61
- def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
62
- pad_token_id = tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else 0
63
- bos_token_id = tokenizer.bos_token_id if hasattr(tokenizer, 'bos_token_id') else 1
64
-
65
- batch = {}
66
-
67
- if "spectrogram" in features[0] and features[0]["spectrogram"] is not None:
68
- spectrogram_list = [f["spectrogram"] for f in features]
69
- max_len_feat = max(f.shape[-1] for f in spectrogram_list)
70
- pad_spectrogram = []
71
- for feat in spectrogram_list:
72
- current_len = feat.shape[-1]
73
- padding = max_len_feat - current_len
74
- if padding > 0:
75
- pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
76
- else:
77
- pad_feat = feat
78
- pad_spectrogram.append(pad_feat)
79
- batch["spectrogram"] = torch.stack(pad_spectrogram)
80
-
81
- if "waveform" in features[0] and features[0]["waveform"] is not None:
82
- waveform_list = [f["waveform"] for f in features]
83
- max_len_wav = max(w.shape[-1] for w in waveform_list)
84
- pad_waveforms = []
85
- for wav in waveform_list:
86
- current_len = wav.shape[-1]
87
- padding = max_len_wav - current_len
88
- if padding > 0:
89
- if wav.ndim == 1:
90
- wav = wav.unsqueeze(0)
91
- pad_wav = F.pad(wav, (0, padding), mode='constant', value=pad_token_id)
92
- else:
93
- pad_wav = wav
94
- pad_waveforms.append(pad_wav)
95
- batch["waveform"] = torch.stack(pad_waveforms)
96
-
97
- if "label" in features[0] and features[0]["label"] is not None:
98
- labels_list = [f["label"] for f in features]
99
- max_len = max(len(l) for l in labels_list)
100
- all_ids = []
101
- all_labels = []
102
-
103
- for label in labels_list:
104
- label_list = label.tolist() if isinstance(label, torch.Tensor) else label
105
- decoder_input = [bos_token_id] + label_list
106
- label_eos = label_list + [pad_token_id]
107
- input_len = max_len + 1 - len(decoder_input)
108
- label_len = max_len + 1 - len(label_eos)
109
- padded_input = decoder_input + [pad_token_id] * input_len
110
- padded_labels = label_eos + [pad_token_id] * label_len
111
- all_ids.append(padded_input)
112
- all_labels.append(padded_labels)
113
- batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
114
- batch["labels"] = torch.tensor(all_labels, dtype=torch.long)
115
-
116
- if "pitch" in features[0] and features[0]["pitch"] is not None:
117
- pitch_list = [f["pitch"] for f in features]
118
- max_len_pitch = max(e.shape[-1] for e in pitch_list)
119
- pad_pitch = []
120
- for pitch in pitch_list:
121
- current_len = pitch.shape[-1]
122
- padding = max_len_pitch - current_len
123
- if padding > 0:
124
- pad_pitch_item = F.pad(pitch, (0, padding), mode='constant', value=pad_token_id)
125
- else:
126
- pad_pitch_item = pitch
127
- pad_pitch.append(pad_pitch_item)
128
- batch["pitch"] = torch.stack(pad_pitch)
129
-
130
- if "f0" in features[0] and features[0]["f0"] is not None:
131
- f0_labels = batch.get("labels", None)
132
- if f0_labels is not None:
133
- target_length = f0_labels.shape[-1]
134
- aligned_list = []
135
- original_list = []
136
- for feature in features:
137
- f0 = feature["f0"]
138
- original_list.append(f0)
139
- if f0.shape[-1] != target_length:
140
- aligned_f0 = align_f0(f0.unsqueeze(0), target_length).squeeze(0)
141
- else:
142
- aligned_f0 = f0
143
- aligned_list.append(aligned_f0)
144
- batch["f0d"] = torch.stack(aligned_list) # [batch_size, target_length]
145
- batch["f0"] = torch.stack(original_list) # [batch_size, original_length]
146
-
147
- if "envelope" in features[0] and features[0]["envelope"] is not None:
148
- env_list = [f["envelope"] for f in features]
149
- max_len = max(f.shape[-1] for f in env_list)
150
- pad_env = []
151
- for feat in env_list:
152
- current_len = feat.shape[-1]
153
- padding = max_len_feat - current_len
154
- if padding > 0:
155
- pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
156
- else:
157
- pad_feat = feat
158
- pad_env.append(pad_feat)
159
- batch["envelope"] = torch.stack(pad_env)
160
-
161
- if "phase" in features[0] and features[0]["phase"] is not None:
162
- ph_list = [f["phase"] for f in features]
163
- max_len = max(f.shape[-1] for f in ph_list)
164
- pad_ph = []
165
- for feat in ph_list:
166
- current_len = feat.shape[-1]
167
- padding = max_len_feat - current_len
168
- if padding > 0:
169
- pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
170
- else:
171
- pad_feat = feat
172
- pad_ph.append(pad_feat)
173
- batch["phase"] = torch.stack(pad_ph)
174
- return batch
175
-
176
- def hilbert_transform(x):
177
- N = x.shape[-1]
178
- xf = torch.fft.rfft(x)
179
- h = torch.zeros(N // 2 + 1, device=x.device, dtype=x.dtype)
180
- if N % 2 == 0:
181
- h[0] = h[N//2] = 1
182
- h[1:N//2] = 2
183
- else:
184
- h[0] = 1
185
- h[1:(N+1)//2] = 2
186
- return torch.fft.irfft(xf * h, n=N)
187
-
188
- def analytic_signal(x):
189
- return x + 1j * hilbert_transform(x)
190
-
191
- def hilbert_transform_2d(x, dim=-1):
192
- N = x.shape[dim]
193
- if dim == -1 or dim == len(x.shape) - 1:
194
- xf = torch.fft.rfft(x)
195
- else:
196
- xf = torch.fft.rfft(x, dim=dim)
197
- h_shape = [1] * len(x.shape)
198
- h_shape[dim] = N // 2 + 1
199
- h = torch.zeros(h_shape, device=x.device, dtype=x.dtype)
200
- if dim == -1 or dim == len(x.shape) - 1:
201
- if N % 2 == 0:
202
- h[..., 0] = h[..., -1] = 1
203
- h[..., 1:-1] = 2
204
- else:
205
- h[..., 0] = 1
206
- h[..., 1:] = 2
207
- else:
208
- pass
209
- return torch.fft.irfft(xf * h, n=N, dim=dim)
210
-
211
- def hilbert_transform_true_2d(x):
212
- xf = torch.fft.rfft2(x)
213
- h1, h2 = torch.meshgrid(
214
- torch.fft.rfftfreq(x.shape[-2]) * 2 - 1,
215
- torch.fft.rfftfreq(x.shape[-1]) * 2 - 1,
216
- indexing='ij')
217
- h = -1j / (math.pi * (h1 + 1j*h2))
218
- h[0, 0] = 0
219
- return torch.fft.irfft2(xf * h.to(x.device))
220
-
221
- def process_spectrogram_with_hilbert(spec):
222
- analytic = spec + 1j * hilbert_transform(spec)
223
- envelope = torch.abs(analytic)
224
- phase = torch.angle(analytic)
225
- return envelope, phase
226
-
227
- def load_wave(wave_data, sample_rate):
228
- if isinstance(wave_data, str):
229
- waveform, sr = torchaudio.load(uri=wave_data, normalize=False)
230
- elif isinstance(wave_data, dict):
231
- waveform = torch.tensor(data=wave_data["array"]).float()
232
- sr = wave_data["sampling_rate"]
233
- else:
234
- raise TypeError("Invalid wave_data format.")
235
-
236
- if waveform.dim() == 1:
237
- waveform = waveform.unsqueeze(0)
238
-
239
- if sr != sample_rate:
240
- original_length = waveform.shape[1]
241
- target_length = int(original_length * (sample_rate / sr))
242
-
243
- resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
244
- waveform = resampler(waveform)
245
-
246
- return waveform.flatten()
247
-
248
- def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, frequency=False,
249
- hop_length=128, fmin=0, fmax=8000, n_mels=128, n_fft=1024, sampling_rate=16000,
250
- pad_mode="constant", center=True, power=2.0, window_fn=torch.hann_window, mel_scale="htk",
251
- norm=None, normalized=False, downsamples=False, period=False, hilbert=False):
252
-
253
- dtype = torch.float32
254
- device = torch.device("cuda:0")
255
- audio = batch["audio"]
256
- sampling_rate = audio["sampling_rate"]
257
- sr = audio["sampling_rate"]
258
- wav = load_wave(wave_data=audio, sample_rate=sr)
259
-
260
- if spectrogram:
261
- transform = torchaudio.transforms.MelSpectrogram(
262
- f_max=fmax,
263
- f_min=fmin,
264
- n_mels=n_mels,
265
- sample_rate=sr,
266
- n_fft=n_fft,
267
- hop_length=hop_length,
268
- norm=norm,
269
- normalized=normalized,
270
- power=power,
271
- center=center,
272
- mel_scale=mel_scale,
273
- window_fn=window_fn,
274
- pad_mode=pad_mode)
275
-
276
- mel_spectrogram = transform(wav)
277
- log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
278
- log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
279
- spec = (log_mel + 4.0) / 4.0
280
- spec = torch.tensor(spec)
281
- batch["spectrogram"] = spec
282
-
283
- if hilbert:
284
- envelope_list = []
285
- phase_list = []
286
-
287
- for ch_idx in range(spec.shape[0]):
288
- envelope, phase = process_spectrogram_with_hilbert(spec[ch_idx])
289
- envelope_list.append(envelope)
290
- phase_list.append(phase)
291
-
292
- batch["envelope"] = torch.stack(envelope_list)
293
- batch["phase"] = torch.stack(phase_list)
294
-
295
- wav_1d = wav.unsqueeze(0)
296
-
297
- if waveforms:
298
- batch["waveform"] = wav_1d
299
-
300
- if pitch:
301
- wav_np = wav.numpy().astype(np.float64)
302
- f0, t = pw.dio(wav_np, sampling_rate,
303
- frame_period=hop_length/sampling_rate*1000)
304
- f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
305
- f0 = torch.from_numpy(f0).float()
306
- batch["pitch"] = f0.unsqueeze(0)
307
-
308
- if frequency:
309
- wav_np = wav.numpy().astype(np.float64)
310
- f0, t = pw.dio(wav_np, sampling_rate,
311
- frame_period=hop_length/sampling_rate*1000)
312
- f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
313
- f0 = f0
314
- batch["f0"] = torch.from_numpy(f0).float()
315
-
316
- if spectrogram and waveforms and pitch:
317
- spec_mean = batch["spectrogram"].mean()
318
- spec_std = batch["spectrogram"].std() + 1e-6
319
- batch["spectrogram"] = (batch["spectrogram"] - spec_mean) / spec_std
320
-
321
- wav_mean = batch["waveform"].mean()
322
- wav_std = batch["waveform"].std() + 1e-6
323
- batch["waveform"] = (batch["waveform"] - wav_mean) / wav_std
324
-
325
- if batch["pitch"].max() > 1.0:
326
- pitch_min = 50.0
327
- pitch_max = 600.0
328
- batch["pitch"] = (batch["pitch"] - pitch_min) / (pitch_max - pitch_min)
329
-
330
- batch["label"] = tokenizer.encode(batch["transcription"], add_special_tokens=False)
331
- return batch
332
-
333
- def compute_metrics(eval_pred, compute_result: bool = True,
334
- print_pred: bool = False, num_samples: int = 0, tokenizer=None, pitch=None, model=None):
335
-
336
- pred_logits = eval_pred.predictions
337
- label_ids = eval_pred.label_ids
338
-
339
- if hasattr(pred_logits, "cpu"):
340
- pred_logits = pred_logits.cpu()
341
- if hasattr(label_ids, "cpu"):
342
- label_ids = label_ids.cpu()
343
- if isinstance(pred_logits, tuple):
344
- pred_ids = pred_logits[0]
345
- else:
346
- pred_ids = pred_logits
347
- if hasattr(pred_ids, "ndim") and pred_ids.ndim == 3:
348
- if not isinstance(pred_ids, torch.Tensor):
349
- pred_ids = torch.tensor(pred_ids)
350
- pred_ids = pred_ids.argmax(dim=-1)
351
- pred_ids = pred_ids.tolist()
352
-
353
- if hasattr(label_ids, "tolist"):
354
- label_ids = label_ids.tolist()
355
-
356
- label_ids = [[0 if token == -100 else token for token in seq] for seq in label_ids]
357
- pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=False)
358
- label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=False)
359
-
360
- if print_pred:
361
- for i in range(min(num_samples, len(pred_str))):
362
- print(f"Preds: {pred_str[i]}")
363
- print(f"Label: {label_str[i]}")
364
- print(f"preds: {pred_ids[i]}")
365
- print(f"label: {label_ids[i]}")
366
- print("--------------------------------")
367
-
368
- pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
369
- label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
370
- wer = 100 * metric.compute(predictions=pred_str, references=label_str)
371
-
372
- if model is None:
373
- global global_model
374
- if 'global_model' in globals():
375
- model = global_model
376
-
377
- if model is not None:
378
- trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000
379
- if trainable_params > 0:
380
- efficiency_score = (100 - wer) / trainable_params
381
- else:
382
- print("Warning: Zero trainable parameters detected")
383
- efficiency_score = 0.0
384
- else:
385
- print("Warning: Model not available for parameter counting")
386
- trainable_params = 0.0
387
- efficiency_score = 0.0
388
-
389
- if hasattr(wer, "item"):
390
- wer = wer.item()
391
-
392
- metrics = {
393
- "wer": float(wer),
394
- "trainable_params_M": float(trainable_params),
395
- "efficiency_score": float(efficiency_score),
396
- }
397
-
398
- return metrics
399
-
400
- logger = logging.getLogger(__name__)
401
-
402
- def create_model(param: Dimensions) -> Echo:
403
- model = Echo(param).to('cuda')
404
- trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
405
- total_params = sum(p.numel() for p in model.parameters())
406
- logger.info(f"Trainable parameters: {trainable_params:,}")
407
- logger.info(f"Total parameters: {total_params:,}")
408
- print(f"Trainable parameters: {trainable_params:,}")
409
- print(f"Total parameters: {total_params:,}")
410
-
411
- return model
412
-
413
- def setup_tokenizer(token: str, local_tokenizer_path: str = "D:/newmodel/model/tokenn/"):
414
- from tokenizers import Tokenizer
415
- tokenizer = Tokenizer.from_file(f"{local_tokenizer_path}/tokenizer.json")
416
- orig_encode = tokenizer.encode
417
- def enc(text, add_special_tokens=True):
418
- ids = orig_encode(text).ids
419
- if not add_special_tokens:
420
- sp_ids = [tokenizer.token_to_id(t) for t in ["<PAD>", "<BOS>", "<EOS>"]]
421
- ids = [id for id in ids if id not in sp_ids]
422
- return ids
423
- def bdec(ids_list, skip_special_tokens=True):
424
- results = []
425
- for ids in ids_list:
426
- if skip_special_tokens:
427
- ids = [id for id in ids if id not in [0, 1, 2]]
428
- results.append(tokenizer.decode(ids))
429
- return results
430
- def save_pretrained(save_dir):
431
- os.makedirs(save_dir, exist_ok=True)
432
- tokenizer.save(f"{save_dir}/tokenizer.json")
433
- tokenizer.encode = enc
434
- tokenizer.batch_decode = bdec
435
- tokenizer.save_pretrained = save_pretrained
436
- tokenizer.pad_token_id = 0
437
- tokenizer.bos_token_id = 1
438
- tokenizer.eos_token_id = 2
439
- return tokenizer
440
-
441
- def prepare_datasets(tokenizer, token: str, sanity_check: bool = False, dataset_config: Optional[Dict] = None) -> Tuple[any, any]:
442
- if dataset_config is None:
443
- dataset_config = {
444
- "spectrogram": True,
445
- "waveforms": True,
446
- "pitch": True,
447
- "frequency": True,
448
- "downsamples": True,
449
- "hop_length": 128,
450
- "fmin": 50,
451
- "fmax": 2000,
452
- "n_mels": 128,
453
- "n_fft": 1024,
454
- "sampling_rate": 16000,
455
- }
456
-
457
- dataset = load_dataset(
458
- "google/fleurs",
459
- "en_us",
460
- token=token,
461
- trust_remote_code=True,
462
- streaming=False)
463
-
464
- dataset = dataset.cast_column(column="audio", feature=Audio(sampling_rate=16000)).select_columns(["audio", "transcription"])
465
-
466
- if sanity_check:
467
- dataset = dataset["test"].take(10)
468
- dataset = dataset.select_columns(["audio", "transcription"])
469
- logger.info(f"Sanity dataset size: {dataset.num_rows}")
470
- print(f"Sanity dataset size: {dataset.num_rows}")
471
- prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
472
-
473
- dataset = dataset.map(
474
- function=prepare_fn,
475
- remove_columns=["audio", "transcription"]
476
- ).with_format(type="torch")
477
- train_dataset = dataset
478
- test_dataset = dataset
479
- else:
480
- def filter_func(x):
481
- return (0 < len(x["transcription"]) < 512 and
482
- len(x["audio"]["array"]) > 0 and
483
- len(x["audio"]["array"]) < 1500 * 160)
484
-
485
- dataset = dataset.filter(filter_func).shuffle(seed=4)
486
- logger.info(f"Dataset size: {dataset['train'].num_rows}, {dataset['test'].num_rows}")
487
- print(f"Dataset size: {dataset['train'].num_rows}, {dataset['test'].num_rows}")
488
- prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
489
- columns_to_remove = list(next(iter(dataset.values())).features)
490
- train_dataset = dataset["train"]
491
- test_dataset = dataset["test"].take(50)
492
- logger.info(f"Train dataset size: {train_dataset.num_rows}, Test dataset size: {test_dataset.num_rows}")
493
-
494
- train_dataset = train_dataset.map(
495
- function=prepare_fn,
496
- remove_columns=columns_to_remove
497
- ).with_format(type="torch")
498
-
499
- test_dataset = test_dataset.map(
500
- function=prepare_fn,
501
- remove_columns=columns_to_remove
502
- ).with_format(type="torch")
503
-
504
- return train_dataset, test_dataset
505
-
506
- def get_training_args(
507
- log_dir: str,
508
- batch_eval_metrics: bool = False,
509
- max_steps: int = 10,
510
- save_steps: int = 1000,
511
- eval_steps: int = 1,
512
- warmup_steps: int = 0,
513
- num_train_epochs: int = 1,
514
- logging_steps: int = 1,
515
- eval_on_start: bool = False,
516
- learning_rate: float = 1e-4,
517
- weight_decay: float = 0.01,
518
- max_grad_norm: float = 1.0,
519
- ) -> Seq2SeqTrainingArguments:
520
-
521
- return Seq2SeqTrainingArguments(
522
- output_dir=log_dir,
523
- per_device_train_batch_size=1,
524
- per_device_eval_batch_size=1,
525
- gradient_accumulation_steps=1,
526
- eval_accumulation_steps=1,
527
- tf32=True,
528
- bf16=True,
529
- eval_strategy="steps",
530
- save_strategy="steps",
531
- max_steps=max_steps,
532
- save_steps=save_steps,
533
- eval_steps=eval_steps,
534
- warmup_steps=warmup_steps,
535
- num_train_epochs=num_train_epochs,
536
- logging_steps=logging_steps,
537
- logging_dir=log_dir,
538
- logging_strategy="steps",
539
- report_to=["tensorboard"],
540
- push_to_hub=False,
541
- disable_tqdm=False,
542
- save_total_limit=1,
543
- label_names=["labels"],
544
- optim="adamw_torch",
545
- lr_scheduler_type="cosine",
546
- learning_rate=learning_rate,
547
- weight_decay=weight_decay,
548
- save_safetensors=False,
549
- eval_on_start=eval_on_start,
550
- batch_eval_metrics=batch_eval_metrics,
551
- max_grad_norm=max_grad_norm,
552
- )
553
-
554
- def main():
555
-
556
- token = ""
557
- log_dir = os.path.join('./output/logs', datetime.now().strftime(format='%m-%d_%H'))
558
- os.makedirs(name=log_dir, exist_ok=True)
559
- tokenizer = setup_tokenizer(token)
560
-
561
- def sanity(sanity: bool):
562
-
563
- if sanity:
564
- training_args = get_training_args(
565
- log_dir,
566
- batch_eval_metrics = False,
567
- max_steps = 10,
568
- save_steps = 0,
569
- eval_steps = 1,
570
- warmup_steps = 0,
571
- logging_steps = 1,
572
- eval_on_start = False,
573
- learning_rate = 5e-6,
574
- weight_decay = 0.01,
575
- )
576
- else:
577
- training_args = get_training_args(
578
- log_dir,
579
- batch_eval_metrics = False,
580
- max_steps = 1000,
581
- save_steps = 1000,
582
- eval_steps = 100,
583
- warmup_steps = 100,
584
- logging_steps = 10,
585
- eval_on_start = False,
586
- learning_rate = 2.5e-4,
587
- weight_decay = 0.01,
588
- )
589
-
590
- return training_args
591
-
592
- param = Dimensions(
593
- mels=128,
594
- aud_ctx=1500,
595
- aud_head=4,
596
- aud_dims=512,
597
- aud_idx=4,
598
- vocab=40000,
599
- text_ctx=512,
600
- text_head=4,
601
- text_dims=512,
602
- text_idx=4,
603
- act="swish",
604
- debug={},#{"encoder", "decoder", "residual", "rotary"},
605
- cross_attn=True,
606
- f0_rotary=False,
607
- features = ["spectrogram"]#, "waveform", "pitch", "f0", "envelope", "phase"],
608
- )
609
-
610
- sanity_check = False
611
- training_args = sanity(sanity_check)
612
- dataset_config = {
613
- "spectrogram": True,
614
- "waveforms": False,
615
- "pitch": False,
616
- "downsamples": False,
617
- "frequency": True,
618
- "hilbert": False,
619
- "hop_length": 128,
620
- "fmin": 150,
621
- "fmax": 2000,
622
- "n_mels": 128,
623
- "n_fft": 1024,
624
- "sampling_rate": 16000,
625
- "pad_mode": "constant",
626
- "center": True,
627
- "power": 2.0,
628
- "window_fn": torch.hann_window,
629
- "mel_scale": "htk",
630
- "norm": None,
631
- "normalized": False}
632
-
633
- model = create_model(param)
634
-
635
- global global_model
636
- global_model = model
637
-
638
- metrics_fn = partial(compute_metrics, print_pred=False, num_samples=5,
639
- tokenizer=tokenizer, model=model)
640
-
641
- print(f"{'Sanity check' if sanity_check else 'Training'} mode")
642
- train_dataset, test_dataset = prepare_datasets(
643
- tokenizer=tokenizer,
644
- token=token,
645
- sanity_check=sanity_check,
646
- dataset_config=dataset_config)
647
-
648
- trainer = Seq2SeqTrainer(
649
- args=training_args,
650
- model=model,
651
- train_dataset=train_dataset,
652
- eval_dataset=test_dataset,
653
- data_collator=DataCollator(tokenizer=tokenizer),
654
- compute_metrics=metrics_fn,
655
- )
656
-
657
- model.init_weights()
658
- trainer.train()
659
-
660
- if __name__ == "__main__":
661
- main()
662
-
 
1
+ import pyworld as pw
2
+ import os
3
+ import math
4
+ import logging
5
+ import torch
6
+ import torchaudio
7
+ import torch.nn.functional as F
8
+ import numpy as np
9
+ from typing import Optional, Dict, Union, List, Tuple, Any
10
+ from functools import partial
11
+ from datetime import datetime
12
+ from datasets import load_dataset, Audio, concatenate_datasets
13
+ from transformers.trainer_seq2seq import Seq2SeqTrainer
14
+ from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
15
+ import evaluate
16
+ from dataclasses import dataclass
17
+
18
+ extractor = None
19
+ tokenizer = None
20
+ optimizer = None
21
+ scheduler = None
22
+ model = None
23
+ Residual = None
24
+ MultiheadA = None
25
+ Echo = None
26
+
27
+ metric = evaluate.load(path="wer")
28
+
29
+ @dataclass
30
+ class Dimensions:
31
+ vocab: int
32
+ text_ctx: int
33
+ text_dims: int
34
+ text_head: int
35
+ text_idx: int
36
+ mels: int
37
+ aud_ctx: int
38
+ aud_dims: int
39
+ aud_head: int
40
+ aud_idx: int
41
+ act: str
42
+ debug: List[str]
43
+ cross_attn: bool
44
+ features: List[str]
45
+ f0_rotary: bool
46
+
47
+ def align_f0(f0, ctx):
48
+ ctx = torch.tensor(ctx)
49
+ bat, length = f0.shape
50
+ if length == ctx:
51
+ return f0
52
+ frames = length / ctx
53
+ idx = torch.arange(ctx, device=f0.device)
54
+ idx = (idx * frames).long()
55
+ batch_idx = torch.arange(bat, device=f0.device).unsqueeze(1)
56
+ return f0[batch_idx, idx.unsqueeze(0).expand(bat, -1)]
57
+
58
+ @dataclass
59
+ class DataCollator:
60
+ tokenizer: Any
61
+ def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
62
+ pad_token_id = tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else 0
63
+ bos_token_id = tokenizer.bos_token_id if hasattr(tokenizer, 'bos_token_id') else 1
64
+
65
+ batch = {}
66
+
67
+ if "spectrogram" in features[0] and features[0]["spectrogram"] is not None:
68
+ spectrogram_list = [f["spectrogram"] for f in features]
69
+ max_len_feat = max(f.shape[-1] for f in spectrogram_list)
70
+ pad_spectrogram = []
71
+ for feat in spectrogram_list:
72
+ current_len = feat.shape[-1]
73
+ padding = max_len_feat - current_len
74
+ if padding > 0:
75
+ pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
76
+ else:
77
+ pad_feat = feat
78
+ pad_spectrogram.append(pad_feat)
79
+ batch["spectrogram"] = torch.stack(pad_spectrogram)
80
+
81
+ if "waveform" in features[0] and features[0]["waveform"] is not None:
82
+ waveform_list = [f["waveform"] for f in features]
83
+ max_len_wav = max(w.shape[-1] for w in waveform_list)
84
+ pad_waveforms = []
85
+ for wav in waveform_list:
86
+ current_len = wav.shape[-1]
87
+ padding = max_len_wav - current_len
88
+ if padding > 0:
89
+ if wav.ndim == 1:
90
+ wav = wav.unsqueeze(0)
91
+ pad_wav = F.pad(wav, (0, padding), mode='constant', value=pad_token_id)
92
+ else:
93
+ pad_wav = wav
94
+ pad_waveforms.append(pad_wav)
95
+ batch["waveform"] = torch.stack(pad_waveforms)
96
+
97
+ if "label" in features[0] and features[0]["label"] is not None:
98
+ labels_list = [f["label"] for f in features]
99
+ max_len = max(len(l) for l in labels_list)
100
+ all_ids = []
101
+ all_labels = []
102
+
103
+ for label in labels_list:
104
+ label_list = label.tolist() if isinstance(label, torch.Tensor) else label
105
+ decoder_input = [bos_token_id] + label_list
106
+ label_eos = label_list + [pad_token_id]
107
+ input_len = max_len + 1 - len(decoder_input)
108
+ label_len = max_len + 1 - len(label_eos)
109
+ padded_input = decoder_input + [pad_token_id] * input_len
110
+ padded_labels = label_eos + [pad_token_id] * label_len
111
+ all_ids.append(padded_input)
112
+ all_labels.append(padded_labels)
113
+ batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
114
+ batch["labels"] = torch.tensor(all_labels, dtype=torch.long)
115
+
116
+ if "pitch" in features[0] and features[0]["pitch"] is not None:
117
+ pitch_list = [f["pitch"] for f in features]
118
+ max_len_pitch = max(e.shape[-1] for e in pitch_list)
119
+ pad_pitch = []
120
+ for pitch in pitch_list:
121
+ current_len = pitch.shape[-1]
122
+ padding = max_len_pitch - current_len
123
+ if padding > 0:
124
+ pad_pitch_item = F.pad(pitch, (0, padding), mode='constant', value=pad_token_id)
125
+ else:
126
+ pad_pitch_item = pitch
127
+ pad_pitch.append(pad_pitch_item)
128
+ batch["pitch"] = torch.stack(pad_pitch)
129
+
130
+ if "f0" in features[0] and features[0]["f0"] is not None:
131
+ input_ids_batch = batch.get("input_ids", None)
132
+ if input_ids_batch is not None:
133
+ target_length = input_ids_batch.shape[-1]
134
+ aligned_list = []
135
+ original_list = []
136
+ for feature in features:
137
+ f0 = feature["f0"]
138
+ original_list.append(f0)
139
+ if f0.shape[-1] != target_length:
140
+ aligned_f0 = align_f0(f0.unsqueeze(0), target_length).squeeze(0)
141
+ else:
142
+ aligned_f0 = f0
143
+ aligned_list.append(aligned_f0)
144
+ batch["f0d"] = torch.stack(aligned_list)
145
+ batch["f0"] = torch.stack(original_list)
146
+
147
+ if "envelope" in features[0] and features[0]["envelope"] is not None:
148
+ env_list = [f["envelope"] for f in features]
149
+ max_len = max(f.shape[-1] for f in env_list)
150
+ pad_env = []
151
+ for feat in env_list:
152
+ current_len = feat.shape[-1]
153
+ padding = max_len_feat - current_len
154
+ if padding > 0:
155
+ pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
156
+ else:
157
+ pad_feat = feat
158
+ pad_env.append(pad_feat)
159
+ batch["envelope"] = torch.stack(pad_env)
160
+
161
+ if "phase" in features[0] and features[0]["phase"] is not None:
162
+ ph_list = [f["phase"] for f in features]
163
+ max_len = max(f.shape[-1] for f in ph_list)
164
+ pad_ph = []
165
+ for feat in ph_list:
166
+ current_len = feat.shape[-1]
167
+ padding = max_len_feat - current_len
168
+ if padding > 0:
169
+ pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
170
+ else:
171
+ pad_feat = feat
172
+ pad_ph.append(pad_feat)
173
+ batch["phase"] = torch.stack(pad_ph)
174
+ return batch
175
+
176
+ def hilbert_transform(x):
177
+ N = x.shape[-1]
178
+ xf = torch.fft.rfft(x)
179
+ h = torch.zeros(N // 2 + 1, device=x.device, dtype=x.dtype)
180
+ if N % 2 == 0:
181
+ h[0] = h[N//2] = 1
182
+ h[1:N//2] = 2
183
+ else:
184
+ h[0] = 1
185
+ h[1:(N+1)//2] = 2
186
+ return torch.fft.irfft(xf * h, n=N)
187
+
188
+ def analytic_signal(x):
189
+ return x + 1j * hilbert_transform(x)
190
+
191
+ def hilbert_transform_2d(x, dim=-1):
192
+ N = x.shape[dim]
193
+ if dim == -1 or dim == len(x.shape) - 1:
194
+ xf = torch.fft.rfft(x)
195
+ else:
196
+ xf = torch.fft.rfft(x, dim=dim)
197
+ h_shape = [1] * len(x.shape)
198
+ h_shape[dim] = N // 2 + 1
199
+ h = torch.zeros(h_shape, device=x.device, dtype=x.dtype)
200
+ if dim == -1 or dim == len(x.shape) - 1:
201
+ if N % 2 == 0:
202
+ h[..., 0] = h[..., -1] = 1
203
+ h[..., 1:-1] = 2
204
+ else:
205
+ h[..., 0] = 1
206
+ h[..., 1:] = 2
207
+ else:
208
+ pass
209
+ return torch.fft.irfft(xf * h, n=N, dim=dim)
210
+
211
+ def hilbert_transform_true_2d(x):
212
+ xf = torch.fft.rfft2(x)
213
+ h1, h2 = torch.meshgrid(
214
+ torch.fft.rfftfreq(x.shape[-2]) * 2 - 1,
215
+ torch.fft.rfftfreq(x.shape[-1]) * 2 - 1,
216
+ indexing='ij')
217
+ h = -1j / (math.pi * (h1 + 1j*h2))
218
+ h[0, 0] = 0
219
+ return torch.fft.irfft2(xf * h.to(x.device))
220
+
221
+ def process_spectrogram_with_hilbert(spec):
222
+ analytic = spec + 1j * hilbert_transform(spec)
223
+ envelope = torch.abs(analytic)
224
+ phase = torch.angle(analytic)
225
+ return envelope, phase
226
+
227
+ def load_wave(wave_data, sample_rate):
228
+ if isinstance(wave_data, str):
229
+ waveform, sr = torchaudio.load(uri=wave_data, normalize=False)
230
+ elif isinstance(wave_data, dict):
231
+ waveform = torch.tensor(data=wave_data["array"]).float()
232
+ sr = wave_data["sampling_rate"]
233
+ else:
234
+ raise TypeError("Invalid wave_data format.")
235
+
236
+ if waveform.dim() == 1:
237
+ waveform = waveform.unsqueeze(0)
238
+
239
+ if sr != sample_rate:
240
+ original_length = waveform.shape[1]
241
+ target_length = int(original_length * (sample_rate / sr))
242
+
243
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
244
+ waveform = resampler(waveform)
245
+
246
+ return waveform.flatten()
247
+
248
+ def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, frequency=False,
249
+ hop_length=128, fmin=0, fmax=8000, n_mels=128, n_fft=1024, sampling_rate=16000,
250
+ pad_mode="constant", center=True, power=2.0, window_fn=torch.hann_window, mel_scale="htk",
251
+ norm=None, normalized=False, downsamples=False, period=False, hilbert=False):
252
+
253
+ dtype = torch.float32
254
+ device = torch.device("cuda:0")
255
+ audio = batch["audio"]
256
+ sampling_rate = audio["sampling_rate"]
257
+ sr = audio["sampling_rate"]
258
+ wav = load_wave(wave_data=audio, sample_rate=sr)
259
+
260
+ if spectrogram:
261
+ transform = torchaudio.transforms.MelSpectrogram(
262
+ f_max=fmax,
263
+ f_min=fmin,
264
+ n_mels=n_mels,
265
+ sample_rate=sr,
266
+ n_fft=n_fft,
267
+ hop_length=hop_length,
268
+ norm=norm,
269
+ normalized=normalized,
270
+ power=power,
271
+ center=center,
272
+ mel_scale=mel_scale,
273
+ window_fn=window_fn,
274
+ pad_mode=pad_mode)
275
+
276
+ mel_spectrogram = transform(wav)
277
+ log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
278
+ log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
279
+ spec = (log_mel + 4.0) / 4.0
280
+ spec = torch.tensor(spec)
281
+ batch["spectrogram"] = spec
282
+
283
+ if hilbert:
284
+ envelope_list = []
285
+ phase_list = []
286
+
287
+ for ch_idx in range(spec.shape[0]):
288
+ envelope, phase = process_spectrogram_with_hilbert(spec[ch_idx])
289
+ envelope_list.append(envelope)
290
+ phase_list.append(phase)
291
+
292
+ batch["envelope"] = torch.stack(envelope_list)
293
+ batch["phase"] = torch.stack(phase_list)
294
+
295
+ wav_1d = wav.unsqueeze(0)
296
+
297
+ if waveforms:
298
+ batch["waveform"] = wav_1d
299
+
300
+ if pitch:
301
+ wav_np = wav.numpy().astype(np.float64)
302
+ f0, t = pw.dio(wav_np, sampling_rate,
303
+ frame_period=hop_length/sampling_rate*1000)
304
+ f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
305
+ f0 = torch.from_numpy(f0).float()
306
+ batch["pitch"] = f0.unsqueeze(0)
307
+
308
+ if frequency:
309
+ wav_np = wav.numpy().astype(np.float64)
310
+ f0, t = pw.dio(wav_np, sampling_rate,
311
+ frame_period=hop_length/sampling_rate*1000)
312
+ f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
313
+ f0 = f0
314
+ batch["f0"] = torch.from_numpy(f0).float()
315
+
316
+ if spectrogram and waveforms and pitch:
317
+ spec_mean = batch["spectrogram"].mean()
318
+ spec_std = batch["spectrogram"].std() + 1e-6
319
+ batch["spectrogram"] = (batch["spectrogram"] - spec_mean) / spec_std
320
+
321
+ wav_mean = batch["waveform"].mean()
322
+ wav_std = batch["waveform"].std() + 1e-6
323
+ batch["waveform"] = (batch["waveform"] - wav_mean) / wav_std
324
+
325
+ if batch["pitch"].max() > 1.0:
326
+ pitch_min = 50.0
327
+ pitch_max = 600.0
328
+ batch["pitch"] = (batch["pitch"] - pitch_min) / (pitch_max - pitch_min)
329
+
330
+ batch["label"] = tokenizer.encode(batch["transcription"], add_special_tokens=False)
331
+ return batch
332
+
333
+ def compute_metrics(eval_pred, compute_result: bool = True,
334
+ print_pred: bool = False, num_samples: int = 0, tokenizer=None, pitch=None, model=None):
335
+
336
+ pred_logits = eval_pred.predictions
337
+ label_ids = eval_pred.label_ids
338
+
339
+ if hasattr(pred_logits, "cpu"):
340
+ pred_logits = pred_logits.cpu()
341
+ if hasattr(label_ids, "cpu"):
342
+ label_ids = label_ids.cpu()
343
+ if isinstance(pred_logits, tuple):
344
+ pred_ids = pred_logits[0]
345
+ else:
346
+ pred_ids = pred_logits
347
+ if hasattr(pred_ids, "ndim") and pred_ids.ndim == 3:
348
+ if not isinstance(pred_ids, torch.Tensor):
349
+ pred_ids = torch.tensor(pred_ids)
350
+ pred_ids = pred_ids.argmax(dim=-1)
351
+ pred_ids = pred_ids.tolist()
352
+
353
+ if hasattr(label_ids, "tolist"):
354
+ label_ids = label_ids.tolist()
355
+
356
+ label_ids = [[0 if token == -100 else token for token in seq] for seq in label_ids]
357
+ pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=False)
358
+ label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=False)
359
+
360
+ if print_pred:
361
+ for i in range(min(num_samples, len(pred_str))):
362
+ print(f"Preds: {pred_str[i]}")
363
+ print(f"Label: {label_str[i]}")
364
+ print(f"preds: {pred_ids[i]}")
365
+ print(f"label: {label_ids[i]}")
366
+ print("--------------------------------")
367
+
368
+ pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
369
+ label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
370
+ wer = 100 * metric.compute(predictions=pred_str, references=label_str)
371
+
372
+ if model is None:
373
+ global global_model
374
+ if 'global_model' in globals():
375
+ model = global_model
376
+
377
+ if model is not None:
378
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000
379
+ if trainable_params > 0:
380
+ efficiency_score = (100 - wer) / trainable_params
381
+ else:
382
+ print("Warning: Zero trainable parameters detected")
383
+ efficiency_score = 0.0
384
+ else:
385
+ print("Warning: Model not available for parameter counting")
386
+ trainable_params = 0.0
387
+ efficiency_score = 0.0
388
+
389
+ if hasattr(wer, "item"):
390
+ wer = wer.item()
391
+
392
+ metrics = {
393
+ "wer": float(wer),
394
+ "trainable_params_M": float(trainable_params),
395
+ "efficiency_score": float(efficiency_score),
396
+ }
397
+
398
+ return metrics
399
+
400
+ logger = logging.getLogger(__name__)
401
+
402
+ def create_model(param: Dimensions) -> Echo:
403
+ model = Echo(param).to('cuda')
404
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
405
+ total_params = sum(p.numel() for p in model.parameters())
406
+ logger.info(f"Trainable parameters: {trainable_params:,}")
407
+ logger.info(f"Total parameters: {total_params:,}")
408
+ print(f"Trainable parameters: {trainable_params:,}")
409
+ print(f"Total parameters: {total_params:,}")
410
+
411
+ return model
412
+
413
+ def setup_tokenizer(token: str, local_tokenizer_path: str = "D:/newmodel/model/tokenn/"):
414
+ from tokenizers import Tokenizer
415
+ tokenizer = Tokenizer.from_file(f"{local_tokenizer_path}/tokenizer.json")
416
+ orig_encode = tokenizer.encode
417
+ def enc(text, add_special_tokens=True):
418
+ ids = orig_encode(text).ids
419
+ if not add_special_tokens:
420
+ sp_ids = [tokenizer.token_to_id(t) for t in ["<PAD>", "<BOS>", "<EOS>"]]
421
+ ids = [id for id in ids if id not in sp_ids]
422
+ return ids
423
+ def bdec(ids_list, skip_special_tokens=True):
424
+ results = []
425
+ for ids in ids_list:
426
+ if skip_special_tokens:
427
+ ids = [id for id in ids if id not in [0, 1, 2]]
428
+ results.append(tokenizer.decode(ids))
429
+ return results
430
+ def save_pretrained(save_dir):
431
+ os.makedirs(save_dir, exist_ok=True)
432
+ tokenizer.save(f"{save_dir}/tokenizer.json")
433
+ tokenizer.encode = enc
434
+ tokenizer.batch_decode = bdec
435
+ tokenizer.save_pretrained = save_pretrained
436
+ tokenizer.pad_token_id = 0
437
+ tokenizer.bos_token_id = 1
438
+ tokenizer.eos_token_id = 2
439
+ return tokenizer
440
+
441
+ def prepare_datasets(tokenizer, token: str, sanity_check: bool = False, dataset_config: Optional[Dict] = None) -> Tuple[any, any]:
442
+ if dataset_config is None:
443
+ dataset_config = {
444
+ "spectrogram": True,
445
+ "waveforms": True,
446
+ "pitch": True,
447
+ "frequency": True,
448
+ "downsamples": True,
449
+ "hop_length": 128,
450
+ "fmin": 50,
451
+ "fmax": 2000,
452
+ "n_mels": 128,
453
+ "n_fft": 1024,
454
+ "sampling_rate": 16000,
455
+ }
456
+
457
+ dataset = load_dataset(
458
+ "google/fleurs",
459
+ "en_us",
460
+ token=token,
461
+ trust_remote_code=True,
462
+ streaming=False)
463
+
464
+ dataset = dataset.cast_column(column="audio", feature=Audio(sampling_rate=16000)).select_columns(["audio", "transcription"])
465
+
466
+ if sanity_check:
467
+ dataset = dataset["test"].take(10)
468
+ dataset = dataset.select_columns(["audio", "transcription"])
469
+ logger.info(f"Sanity dataset size: {dataset.num_rows}")
470
+ print(f"Sanity dataset size: {dataset.num_rows}")
471
+ prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
472
+
473
+ dataset = dataset.map(
474
+ function=prepare_fn,
475
+ remove_columns=["audio", "transcription"]
476
+ ).with_format(type="torch")
477
+ train_dataset = dataset
478
+ test_dataset = dataset
479
+ else:
480
+ def filter_func(x):
481
+ return (0 < len(x["transcription"]) < 512 and
482
+ len(x["audio"]["array"]) > 0 and
483
+ len(x["audio"]["array"]) < 1500 * 160)
484
+
485
+ dataset = dataset.filter(filter_func).shuffle(seed=4)
486
+ logger.info(f"Dataset size: {dataset['train'].num_rows}, {dataset['test'].num_rows}")
487
+ print(f"Dataset size: {dataset['train'].num_rows}, {dataset['test'].num_rows}")
488
+ prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
489
+ columns_to_remove = list(next(iter(dataset.values())).features)
490
+ train_dataset = dataset["train"]
491
+ test_dataset = dataset["test"].take(50)
492
+ logger.info(f"Train dataset size: {train_dataset.num_rows}, Test dataset size: {test_dataset.num_rows}")
493
+
494
+ train_dataset = train_dataset.map(
495
+ function=prepare_fn,
496
+ remove_columns=columns_to_remove
497
+ ).with_format(type="torch")
498
+
499
+ test_dataset = test_dataset.map(
500
+ function=prepare_fn,
501
+ remove_columns=columns_to_remove
502
+ ).with_format(type="torch")
503
+
504
+ return train_dataset, test_dataset
505
+
506
+ def get_training_args(
507
+ log_dir: str,
508
+ batch_eval_metrics: bool = False,
509
+ max_steps: int = 10,
510
+ save_steps: int = 1000,
511
+ eval_steps: int = 1,
512
+ warmup_steps: int = 0,
513
+ num_train_epochs: int = 1,
514
+ logging_steps: int = 1,
515
+ eval_on_start: bool = False,
516
+ learning_rate: float = 1e-4,
517
+ weight_decay: float = 0.01,
518
+ max_grad_norm: float = 1.0,
519
+ ) -> Seq2SeqTrainingArguments:
520
+
521
+ return Seq2SeqTrainingArguments(
522
+ output_dir=log_dir,
523
+ per_device_train_batch_size=1,
524
+ per_device_eval_batch_size=1,
525
+ gradient_accumulation_steps=1,
526
+ eval_accumulation_steps=1,
527
+ tf32=True,
528
+ bf16=True,
529
+ eval_strategy="steps",
530
+ save_strategy="steps",
531
+ max_steps=max_steps,
532
+ save_steps=save_steps,
533
+ eval_steps=eval_steps,
534
+ warmup_steps=warmup_steps,
535
+ num_train_epochs=num_train_epochs,
536
+ logging_steps=logging_steps,
537
+ logging_dir=log_dir,
538
+ logging_strategy="steps",
539
+ report_to=["tensorboard"],
540
+ push_to_hub=False,
541
+ disable_tqdm=False,
542
+ save_total_limit=1,
543
+ label_names=["labels"],
544
+ optim="adamw_torch",
545
+ lr_scheduler_type="cosine",
546
+ learning_rate=learning_rate,
547
+ weight_decay=weight_decay,
548
+ save_safetensors=False,
549
+ eval_on_start=eval_on_start,
550
+ batch_eval_metrics=batch_eval_metrics,
551
+ max_grad_norm=max_grad_norm,
552
+ )
553
+
554
+ def main():
555
+
556
+ token = ""
557
+ log_dir = os.path.join('./output/logs', datetime.now().strftime(format='%m-%d_%H'))
558
+ os.makedirs(name=log_dir, exist_ok=True)
559
+ tokenizer = setup_tokenizer(token)
560
+
561
+ def sanity(sanity: bool):
562
+
563
+ if sanity:
564
+ training_args = get_training_args(
565
+ log_dir,
566
+ batch_eval_metrics = False,
567
+ max_steps = 10,
568
+ save_steps = 0,
569
+ eval_steps = 1,
570
+ warmup_steps = 0,
571
+ logging_steps = 1,
572
+ eval_on_start = False,
573
+ learning_rate = 5e-6,
574
+ weight_decay = 0.01,
575
+ )
576
+ else:
577
+ training_args = get_training_args(
578
+ log_dir,
579
+ batch_eval_metrics = False,
580
+ max_steps = 1000,
581
+ save_steps = 1000,
582
+ eval_steps = 100,
583
+ warmup_steps = 100,
584
+ logging_steps = 10,
585
+ eval_on_start = False,
586
+ learning_rate = 2.5e-4,
587
+ weight_decay = 0.01,
588
+ )
589
+
590
+ return training_args
591
+
592
+ param = Dimensions(
593
+ mels=128,
594
+ aud_ctx=1500,
595
+ aud_head=4,
596
+ aud_dims=512,
597
+ aud_idx=4,
598
+ vocab=40000,
599
+ text_ctx=512,
600
+ text_head=4,
601
+ text_dims=512,
602
+ text_idx=4,
603
+ act="swish",
604
+ debug={},#{"encoder", "decoder", "residual", "rotary"},
605
+ cross_attn=True,
606
+ f0_rotary=False,
607
+ features = ["spectrogram"]#, "waveform", "pitch", "f0", "envelope", "phase"],
608
+ )
609
+
610
+ sanity_check = False
611
+ training_args = sanity(sanity_check)
612
+ dataset_config = {
613
+ "spectrogram": True,
614
+ "waveforms": False,
615
+ "pitch": False,
616
+ "downsamples": False,
617
+ "frequency": True,
618
+ "hilbert": False,
619
+ "hop_length": 128,
620
+ "fmin": 150,
621
+ "fmax": 2000,
622
+ "n_mels": 128,
623
+ "n_fft": 1024,
624
+ "sampling_rate": 16000,
625
+ "pad_mode": "constant",
626
+ "center": True,
627
+ "power": 2.0,
628
+ "window_fn": torch.hann_window,
629
+ "mel_scale": "htk",
630
+ "norm": None,
631
+ "normalized": False}
632
+
633
+ model = create_model(param)
634
+
635
+ global global_model
636
+ global_model = model
637
+
638
+ metrics_fn = partial(compute_metrics, print_pred=False, num_samples=5,
639
+ tokenizer=tokenizer, model=model)
640
+
641
+ print(f"{'Sanity check' if sanity_check else 'Training'} mode")
642
+ train_dataset, test_dataset = prepare_datasets(
643
+ tokenizer=tokenizer,
644
+ token=token,
645
+ sanity_check=sanity_check,
646
+ dataset_config=dataset_config)
647
+
648
+ trainer = Seq2SeqTrainer(
649
+ args=training_args,
650
+ model=model,
651
+ train_dataset=train_dataset,
652
+ eval_dataset=test_dataset,
653
+ data_collator=DataCollator(tokenizer=tokenizer),
654
+ compute_metrics=metrics_fn,
655
+ )
656
+
657
+ model.init_weights()
658
+ trainer.train()
659
+
660
+ if __name__ == "__main__":
661
+ main()
662
+