Gagan Bhatia commited on
Commit
5288717
·
1 Parent(s): c59f6db

Update model.py

Browse files
Files changed (1) hide show
  1. src/models/model.py +1 -0
src/models/model.py CHANGED
@@ -362,6 +362,7 @@ class Summarization:
362
  elif model_type == "mt5":
363
  self.tokenizer = MT5Tokenizer.from_pretrained(f"{model_dir}")
364
  self.model = MT5ForConditionalGeneration.from_pretrained(
 
365
 
366
  if use_gpu:
367
  if torch.cuda.is_available():
 
362
  elif model_type == "mt5":
363
  self.tokenizer = MT5Tokenizer.from_pretrained(f"{model_dir}")
364
  self.model = MT5ForConditionalGeneration.from_pretrained(
365
+ f"{model_dir}", return_dict=True
366
 
367
  if use_gpu:
368
  if torch.cuda.is_available():