File size: 1,508 Bytes
0901f7b
9d42f47
 
 
 
0681bc6
40bc3ae
1e660fe
a208e00
99b2dac
691657e
30eb466
d275124
64226c2
 
187bb52
 
171817a
 
 
74e821f
 
 
ea2c6ba
 
 
336121a
 
 
5c300b0
 
 
8e6d41d
 
 
08eb467
 
168d045
 
df20fa3
 
 
b62e2d9
 
fae8953
 
15505de
 
 
5686d0a
 
 
 
d50b787
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import time
import torch
import numpy as np
import pandas as pd
from datasets import load_metric
from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5TokenizerFast as T5Tokenizer,
)
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import MLFlowLogger
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning import LightningDataModule
from pytorch_lightning import LightningModule

# Import-time setup: both calls below run as a side effect of importing this module.
# Ask PyTorch to release its cached CUDA memory (presumably to start training
# from a clean GPU state — NOTE(review): confirm this is still needed).
torch.cuda.empty_cache()
# Seed the random number generators through PyTorch Lightning with the fixed
# value 42 (presumably for run-to-run reproducibility).
pl.seed_everything(42)


class DataModule(Dataset):
    """
    Data Module for pytorch
    """

    def __init__(
            self,
            data: pd.DataFrame,
            tokenizer: T5Tokenizer,
            source_max_token_len: int = 512,
            target_max_token_len: int = 512,
    ):
        """
        :param data:
        :param tokenizer:
        :param source_max_token_len:
        :param target_max_token_len:
        """
        self.data = data
        self.target_max_token_len = target_max_token_len
        self.source_max_token_len = source_max_token_len
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index: int):
        data_row = self.data.iloc[index]

        input_encoding = self.tokenizer(
            data_row["input_text"],
            max_length=self.source_max_token_len,
            padding="max_length",
            truncation=True,