Gagan Bhatia commited on
Commit
3c8cb17
·
1 Parent(s): 4c60028

Update train_model.py

Browse files
Files changed (1) hide show
  1. src/models/train_model.py +3 -3
src/models/train_model.py CHANGED
@@ -16,9 +16,9 @@ def train_model():
16
  eval_df = pd.read_csv('data/processed/validation.csv')
17
 
18
  model = Summarization()
19
- model.from_pretrained('t5','t5-base')
20
- model.train(train_df=train_df, eval_df=eval_df, batch_size=4, max_epochs=3, use_gpu=True)
21
- model.save_model()
22
 
23
 
24
  if __name__ == '__main__':
 
16
  eval_df = pd.read_csv('data/processed/validation.csv')
17
 
18
  model = Summarization()
19
+ model.from_pretrained(model_type=params['model_type'], model_name=params['model_name'])
20
+
21
+ model.train(train_df=train_df, eval_df=eval_df,
22
 
23
 
24
  if __name__ == '__main__':