File size: 1,198 Bytes
adeafa0
f8a2915
cdce7a5
 
dfda3c6
cb1227c
 
297c713
 
 
 
 
cdce7a5
 
 
d51b694
4c60028
 
d51b694
d8baba8
5cb4c53
 
d51b694
3c8cb17
 
 
72b8e68
 
 
f2e152e
 
7f5fe96
45e8538
 
 
7d39284
05f29e7
 
7f5fe96
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import json
from pathlib import Path

import pandas as pd
import yaml

from model import Summarization


def _subsample(df, frac, random_state=1):
    """
    Return a reproducible random sample of ``frac`` of the rows of ``df``.

    Rows are drawn without replacement so the subset never contains
    duplicated examples (duplicates would over-weight some samples during
    training and skew validation metrics). Replacement is used only when
    ``frac`` > 1, because pandas refuses to upsample without it.
    """
    return df.sample(frac=frac, replace=frac > 1, random_state=random_state)


def _export_training_metrics(summary_path, report_path):
    """
    Copy the JSON document at ``summary_path`` to ``report_path``.

    The parent directory of ``report_path`` is created if missing, so a
    checkout without a ``reports/`` folder does not fail after training has
    already completed.
    """
    with open(summary_path, encoding="utf-8") as json_file:
        data = json.load(json_file)

    report = Path(report_path)
    report.parent.mkdir(parents=True, exist_ok=True)
    with report.open("w", encoding="utf-8") as fp:
        json.dump(data, fp)


def train_model():
    """
    Train the summarization model and export its training metrics.

    Steps:
      1. Read hyperparameters from ``params.yml``.
      2. Load the processed train/validation CSVs and subsample each to the
         ``split`` fraction (fixed seed, so the subsets are reproducible).
      3. Initialize a ``Summarization`` model via ``from_pretrained``, train
         it, and save it to ``model_dir``.
      4. Copy ``wandb/latest-run/files/wandb-summary.json`` to
         ``reports/training_metrics.txt``.

    Raises:
        FileNotFoundError: if ``params.yml``, a data CSV, or the wandb
            summary file is missing.
        KeyError: if a required key is missing from ``params.yml``.
    """
    with open("params.yml", encoding="utf-8") as f:
        params = yaml.safe_load(f)

    # Load the data and keep only the configured fraction of each split.
    train_df = _subsample(pd.read_csv('data/processed/train.csv'), params['split'])
    eval_df = _subsample(pd.read_csv('data/processed/validation.csv'), params['split'])

    model = Summarization()
    model.from_pretrained(model_type=params['model_type'], model_name=params['model_name'])

    # Explicit casts: PyYAML may load values such as ``1e-4`` (no decimal
    # point) as strings rather than floats.
    model.train(train_df=train_df, eval_df=eval_df,
                batch_size=params['batch_size'], max_epochs=params['epochs'],
                use_gpu=params['use_gpu'], learning_rate=float(params['learning_rate']),
                num_workers=int(params['num_workers']))

    model.save_model(model_dir=params['model_dir'])

    # NOTE(review): the summary file is presumably written by wandb logging
    # inside ``model.train`` — not visible from this file; confirm.
    _export_training_metrics(
        'wandb/latest-run/files/wandb-summary.json',
        'reports/training_metrics.txt',
    )


if __name__ == "__main__":
    # Run the training pipeline only when executed as a script, not on import.
    train_model()