import time
import torch
import numpy as np
import pandas as pd
from datasets import load_metric
from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5TokenizerFast as T5Tokenizer,
)
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer, LightningDataModule, LightningModule
from pytorch_lightning.loggers import MLFlowLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# Free any cached GPU memory and fix random seeds for reproducibility.
torch.cuda.empty_cache()
pl.seed_everything(42)

class DataModule(Dataset):
    """
    PyTorch Dataset that tokenizes source/target text pairs for T5.
    """

    def __init__(
        self,
        data: pd.DataFrame,
        tokenizer: T5Tokenizer,
        source_max_token_len: int = 512,
        target_max_token_len: int = 512,
    ):
        """
        :param data: DataFrame containing the source ("input_text") and target text
        :param tokenizer: T5 tokenizer used to encode the text
        :param source_max_token_len: maximum number of tokens for the source text
        :param target_max_token_len: maximum number of tokens for the target text
        """
        self.data = data
        self.target_max_token_len = target_max_token_len
        self.source_max_token_len = source_max_token_len
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index: int):
        # Look up the row and encode the source text, padding/truncating
        # to the configured maximum source length.
        data_row = self.data.iloc[index]
        input_encoding = self.tokenizer(
            data_row["input_text"],
            max_length=self.source_max_token_len,
            padding="max_length",
            truncation=True,