Gagan Bhatia committed on
Commit fafca07 · 1 Parent(s): 056c147

Update model.py

Files changed (1)
  1. src/models/model.py +248 -0
src/models/model.py CHANGED
@@ -185,3 +185,251 @@ class LightningModel(LightningModule):
             input_ids=input_ids,
             attention_mask=attention_mask,
             decoder_attention_mask=labels_attention_mask,
+            labels=labels,
+        )
+        self.log("train_loss", loss, prog_bar=True, logger=True)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        """ validation step """
+        input_ids = batch["keywords_input_ids"]
+        attention_mask = batch["keywords_attention_mask"]
+        labels = batch["labels"]
+        labels_attention_mask = batch["labels_attention_mask"]
+
+        loss, outputs = self(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_attention_mask=labels_attention_mask,
+            labels=labels,
+        )
+        self.log("val_loss", loss, prog_bar=True, logger=True)
+        return loss
+
+    def test_step(self, batch, batch_idx):
+        """ test step """
+        input_ids = batch["keywords_input_ids"]
+        attention_mask = batch["keywords_attention_mask"]
+        labels = batch["labels"]
+        labels_attention_mask = batch["labels_attention_mask"]
+
+        loss, outputs = self(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_attention_mask=labels_attention_mask,
+            labels=labels,
+        )
+
+        self.log("test_loss", loss, prog_bar=True, logger=True)
+        return loss
+
+    def configure_optimizers(self):
+        """ configure optimizers """
+        model = self.model
+        # exclude bias and LayerNorm weights from weight decay
+        no_decay = ["bias", "LayerNorm.weight"]
+        optimizer_grouped_parameters = [
+            {
+                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+                "weight_decay": self.hparams.weight_decay,
+            },
+            {
+                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+                "weight_decay": 0.0,
+            },
+        ]
+        optimizer = AdamW(
+            optimizer_grouped_parameters,
+            lr=self.hparams.learning_rate,
+            eps=self.hparams.adam_epsilon,
+        )
+        self.opt = optimizer
+        return [optimizer]
+
+
+
+class Summarization:
+    """ Custom Summarization class """
+
+    def __init__(self) -> None:
+        """ initializes the Summarization class """
+        pass
+
+    def from_pretrained(self, model_name="t5-base") -> None:
+        """
+        loads a T5/MT5 model for training/fine-tuning
+        Args:
+            model_name (str, optional): exact model architecture name, "t5-base" or "t5-large". Defaults to "t5-base".
+        """
+        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
+        self.model = T5ForConditionalGeneration.from_pretrained(
+            model_name, return_dict=True
+        )
+        # default to CPU so predict() also works when load_model() was never called
+        self.device = torch.device("cpu")
+
+    def train(
+        self,
+        train_df: pd.DataFrame,
+        eval_df: pd.DataFrame,
+        source_max_token_len: int = 512,
+        target_max_token_len: int = 512,
+        batch_size: int = 8,
+        max_epochs: int = 5,
+        use_gpu: bool = True,
+        outputdir: str = "models",
+        early_stopping_patience_epochs: int = 0,  # 0 to disable early stopping
+    ):
+        """
+        trains a T5/MT5 model on a custom dataset
+        Args:
+            train_df (pd.DataFrame): training dataframe. Must have the two columns "input_text" and "output_text".
+            eval_df (pd.DataFrame): validation dataframe. Must have the two columns "input_text" and "output_text".
+            source_max_token_len (int, optional): max token length of source text. Defaults to 512.
+            target_max_token_len (int, optional): max token length of target text. Defaults to 512.
+            batch_size (int, optional): batch size. Defaults to 8.
+            max_epochs (int, optional): max number of epochs. Defaults to 5.
+            use_gpu (bool, optional): if True, the model uses the GPU for training. Defaults to True.
+            outputdir (str, optional): output directory to save model checkpoints. Defaults to "models".
+            early_stopping_patience_epochs (int, optional): monitors val_loss at epoch end and stops training
+                if val_loss does not improve after the specified number of epochs. Set to 0 to disable early
+                stopping. Defaults to 0 (disabled).
+        """
+        self.target_max_token_len = target_max_token_len
+        self.data_module = PLDataModule(
+            train_df,
+            eval_df,
+            self.tokenizer,
+            batch_size=batch_size,
+            source_max_token_len=source_max_token_len,
+            target_max_token_len=target_max_token_len,
+        )
+
+        self.T5Model = LightningModel(
+            tokenizer=self.tokenizer, model=self.model, output=outputdir
+        )
+
+        # checkpoint_callback = ModelCheckpoint(
+        #     dirpath="checkpoints",
+        #     filename="best-checkpoint-{epoch}-{train_loss:.2f}",
+        #     save_top_k=-1,
+        #     verbose=True,
+        #     monitor="train_loss",
+        #     mode="min",
+        # )
+
+        logger = MLFlowLogger(experiment_name="Summarization")
+
+        early_stop_callback = (
+            [
+                EarlyStopping(
+                    monitor="val_loss",
+                    min_delta=0.00,
+                    patience=early_stopping_patience_epochs,
+                    verbose=True,
+                    mode="min",
+                )
+            ]
+            if early_stopping_patience_epochs > 0
+            else None
+        )
+
+        gpus = 1 if use_gpu else 0
+
+        trainer = Trainer(
+            logger=logger,
+            callbacks=early_stop_callback,
+            max_epochs=max_epochs,
+            gpus=gpus,
+            progress_bar_refresh_rate=5,
+        )
+
+        trainer.fit(self.T5Model, self.data_module)
+
+    def load_model(
+        self, model_dir: str = "../../models", use_gpu: bool = False
+    ):
+        """
+        loads a saved checkpoint for inference/prediction
+        Args:
+            model_dir (str, optional): path to the model directory. Defaults to "../../models".
+            use_gpu (bool, optional): if True, the model uses the GPU for inference/prediction. Defaults to False.
+        """
+        self.model = T5ForConditionalGeneration.from_pretrained(model_dir)
+        self.tokenizer = T5Tokenizer.from_pretrained(model_dir)
+
+        if use_gpu:
+            if torch.cuda.is_available():
+                self.device = torch.device("cuda")
+            else:
+                raise Exception("No GPU found. Set use_gpu=False to run on the CPU.")
+        else:
+            self.device = torch.device("cpu")
+
+        self.model = self.model.to(self.device)
+
+    def save_model(
+        self,
+        model_dir="../../models",
+    ):
+        """
+        saves the tokenizer and model to model_dir
+        Args:
+            model_dir (str, optional): output directory. Defaults to "../../models".
+        """
+        self.tokenizer.save_pretrained(model_dir)
+        self.model.save_pretrained(model_dir)
+
+    def predict(
+        self,
+        source_text: str,
+        max_length: int = 512,
+        num_return_sequences: int = 1,
+        num_beams: int = 2,
+        top_k: int = 50,
+        top_p: float = 0.95,
+        do_sample: bool = True,
+        repetition_penalty: float = 2.5,
+        length_penalty: float = 1.0,
+        early_stopping: bool = True,
+        skip_special_tokens: bool = True,
+        clean_up_tokenization_spaces: bool = True,
+    ):
+        """
+        generates predictions with the T5/MT5 model
+        Args:
+            source_text (str): any text for generating predictions
+            max_length (int, optional): max token length of the prediction. Defaults to 512.
+            num_return_sequences (int, optional): number of predictions to be returned. Defaults to 1.
+            num_beams (int, optional): number of beams. Defaults to 2.
+            top_k (int, optional): Defaults to 50.
+            top_p (float, optional): Defaults to 0.95.
+            do_sample (bool, optional): Defaults to True.
+            repetition_penalty (float, optional): Defaults to 2.5.
+            length_penalty (float, optional): Defaults to 1.0.
+            early_stopping (bool, optional): Defaults to True.
+            skip_special_tokens (bool, optional): Defaults to True.
+            clean_up_tokenization_spaces (bool, optional): Defaults to True.
+        Returns:
+            list[str]: predictions
+        """
+        input_ids = self.tokenizer.encode(
+            source_text, return_tensors="pt", add_special_tokens=True
+        )
+
+        input_ids = input_ids.to(self.device)
+        generated_ids = self.model.generate(
+            input_ids=input_ids,
+            num_beams=num_beams,
+            max_length=max_length,
+            repetition_penalty=repetition_penalty,
+            length_penalty=length_penalty,
+            early_stopping=early_stopping,
+            top_p=top_p,
+            top_k=top_k,
+            do_sample=do_sample,
+            num_return_sequences=num_return_sequences,
+        )
+        preds = [
+            self.tokenizer.decode(
+                g,
+                skip_special_tokens=skip_special_tokens,
+                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+            )
+            for g in generated_ids
+        ]
+        return preds
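
For context, a minimal usage sketch of the Summarization class this commit adds. The CSV paths and the example input text are hypothetical; the only assumption from the source is that both dataframes carry the "input_text" and "output_text" columns that train() expects.

import pandas as pd

from src.models.model import Summarization

# Hypothetical dataset files with "input_text" and "output_text" columns.
train_df = pd.read_csv("data/train.csv")
eval_df = pd.read_csv("data/eval.csv")

# Fine-tune t5-base and save the weights and tokenizer.
model = Summarization()
model.from_pretrained("t5-base")
model.train(train_df=train_df, eval_df=eval_df, batch_size=8, max_epochs=5, use_gpu=True)
model.save_model(model_dir="models")

# Reload the checkpoint and generate a summary on CPU.
model = Summarization()
model.load_model(model_dir="models", use_gpu=False)
print(model.predict("summarize: PyTorch Lightning separates research code from engineering ...")[0])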