File size: 572 Bytes
cdce7a5
 
cb1227c
 
 
297c713
 
 
 
 
cdce7a5
 
 
d51b694
4c60028
 
d51b694
 
958efd4
d51b694
7f5fe96
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import yaml

from src.models.model import Summarization
import pandas as pd


def train_model(
    *,
    params_path="params.yml",
    train_path="data/processed/train.csv",
    eval_path="data/processed/validation.csv",
    model_type="t5",
    model_name="t5-base",
    batch_size=4,
    max_epochs=3,
    use_gpu=True,
):
    """
    Train a summarization model on the processed train/validation CSVs and save it.

    Every keyword argument defaults to the value this script previously
    hard-coded, so ``train_model()`` with no arguments behaves as before.

    Args:
        params_path: YAML parameters file, parsed with ``yaml.safe_load``.
        train_path: CSV of training data, read with ``pd.read_csv``.
        eval_path: CSV of validation data, read with ``pd.read_csv``.
        model_type: first argument passed to ``Summarization.from_pretrained``.
        model_name: second argument passed to ``Summarization.from_pretrained``.
        batch_size: forwarded to ``Summarization.train``.
        max_epochs: forwarded to ``Summarization.train``.
        use_gpu: forwarded to ``Summarization.train``.

    Raises:
        FileNotFoundError: if the params file or either CSV file is missing.
    """
    # NOTE(review): params.yml is parsed but none of its values are consumed;
    # the hyperparameters come from the arguments above. Confirm whether they
    # were meant to be read from this file. The parse is kept so a missing or
    # malformed params file still fails fast before any data is loaded.
    with open(params_path, encoding="utf-8") as f:
        yaml.safe_load(f)

    # Load the data
    train_df = pd.read_csv(train_path)
    eval_df = pd.read_csv(eval_path)

    model = Summarization()
    model.from_pretrained(model_type, model_name)
    model.train(
        train_df=train_df,
        eval_df=eval_df,
        batch_size=batch_size,
        max_epochs=max_epochs,
        use_gpu=use_gpu,
    )
    model.save_model()


if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    train_model()