import os
import torch
from torch.utils.data import DataLoader
import omegaconf
from omegaconf import DictConfig
import wandb

from checkpoing_saver import CheckpointSaver
from denoisers import get_model
from optimizers import get_optimizer
from losses import get_loss
from datasets import get_datasets
from testing.metrics import Metrics
from datasets.minimal import Minimal


def train(cfg: DictConfig):
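    """Train a denoising model end to end, driven entirely by `cfg`:
    W&B logging, PESQ/STOI tracking, audio-example logging, checkpointing."""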
    device = torch.device(f'cuda:{cfg.gpu}' if torch.cuda.is_available() else 'cpu')

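    # Authenticate with Weights & Biases and start a tracked run.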
    wandb.login(key=cfg['wandb']['api_key'], host=cfg['wandb']['host'])
    wandb.init(project=cfg['wandb']['project'],
               notes=cfg['wandb']['notes'],
               tags=cfg['wandb']['tags'],
               config=omegaconf.OmegaConf.to_container(
                   cfg, resolve=True, throw_on_missing=True))
    wandb.run.name = cfg['wandb']['run_name']

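    # Checkpoint writer plus objective speech-quality metrics (PESQ/STOI).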
    checkpoint_saver = CheckpointSaver(dirpath=cfg['training']['model_save_path'], run_name=wandb.run.name)
    metrics = Metrics(rate=cfg['dataloader']['sample_rate'])

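    # Instantiate the model, optimizer, loss, and datasets from the config.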
    model = get_model(cfg['model']).to(device)
    optimizer = get_optimizer(model.parameters(), cfg['optimizer'])
    loss_fn = get_loss(cfg['loss'], device)
    train_dataset, valid_dataset = get_datasets(cfg)
    minimal_dataset = Minimal(cfg)

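    # The 'minimal' loader (default batch size 1) feeds the audio-example
    # logging in the validation phase below.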
    dataloaders = {
        'train': DataLoader(train_dataset, batch_size=cfg['dataloader']['train_batch_size'], shuffle=True),
        'val': DataLoader(valid_dataset, batch_size=cfg['dataloader']['valid_batch_size'], shuffle=True),
        'minimal': DataLoader(minimal_dataset)
    }

    wandb.watch(model, log_freq=100)

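    # One training pass and one validation pass per epoch.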
    for epoch in range(cfg['training']['num_epochs']):
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

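            # Accumulate loss and metric totals over the epoch.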
            running_loss, running_pesq, running_stoi = 0.0, 0.0, 0.0
            for i, (inputs, labels) in enumerate(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

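                # Enable gradients only during the training phase.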
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = loss_fn(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

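                # Weight loss by batch size so the epoch averages below are per-sample.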
                running_metrics = metrics.calculate(denoised=outputs, clean=labels)
                running_loss += loss.item() * inputs.size(0)
                running_pesq += running_metrics['PESQ']
                running_stoi += running_metrics['STOI']

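                # Mid-epoch logging of running averages (assumes a constant batch size).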
                if phase == 'train' and i % cfg['wandb']['log_interval'] == 0:
                    wandb.log({"train_loss": running_loss / (i + 1) / inputs.size(0),
                               "train_pesq": running_pesq / (i + 1) / inputs.size(0),
                               "train_stoi": running_stoi / (i + 1) / inputs.size(0)})

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_pesq = running_pesq / len(dataloaders[phase].dataset)
            epoch_stoi = running_stoi / len(dataloaders[phase].dataset)

            wandb.log({f"{phase}_loss": epoch_loss,
                       f"{phase}_pesq": epoch_pesq,
                       f"{phase}_stoi": epoch_stoi})

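            # After each validation pass, log denoised minimal examples to W&B.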
            if phase == 'val':
                with torch.no_grad():
                    for i, (wav, rate) in enumerate(dataloaders['minimal']):
                        prediction = model(wav.to(device))
                        wandb.log({
                            f"{i}_example": wandb.Audio(
                                prediction.detach().cpu().numpy()[0][0],
                                sample_rate=int(rate))})

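                # Save a checkpoint keyed on validation PESQ.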
                checkpoint_saver(model, epoch, metric_val=epoch_pesq,
                                 optimizer=optimizer, loss=epoch_loss)


if __name__ == "__main__":
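    # train() expects a fully composed DictConfig, presumably supplied by a
    # separate entry point (e.g., a Hydra launcher), so running this module
    # directly is a no-op.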
    pass