File size: 2,444 Bytes
33d4721
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch


class TextClassificationDataset:
    """
    A dataset class for text classification tasks.

    Args:
        data (list): The dataset containing text and target columns.
        tokenizer (PreTrainedTokenizer): The tokenizer to preprocess the text data.
        config (object): Configuration object containing dataset parameters.

    Attributes:
        data (list): The dataset containing text and target columns.
        tokenizer (PreTrainedTokenizer): The tokenizer to preprocess the text data.
        config (object): Configuration object containing dataset parameters.
        text_column (str): The name of the column containing text data.
        target_column (str): The name of the column containing target labels.

    Methods:
        __len__(): Returns the number of samples in the dataset.
        __getitem__(item): Returns a dictionary containing tokenized input ids, attention mask, token type ids (if available), and target labels for the given item index.
    """

    def __init__(self, data, tokenizer, config):
        self.data = data
        self.tokenizer = tokenizer
        self.config = config
        self.text_column = self.config.text_column
        self.target_column = self.config.target_column

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        text = str(self.data[item][self.text_column])
        target = self.data[item][self.target_column]
        target = int(target)
        inputs = self.tokenizer(
            text,
            max_length=self.config.max_seq_length,
            padding="max_length",
            truncation=True,
        )

        ids = inputs["input_ids"]
        mask = inputs["attention_mask"]

        if "token_type_ids" in inputs:
            token_type_ids = inputs["token_type_ids"]
        else:
            token_type_ids = None

        if token_type_ids is not None:
            return {
                "input_ids": torch.tensor(ids, dtype=torch.long),
                "attention_mask": torch.tensor(mask, dtype=torch.long),
                "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
                "labels": torch.tensor(target, dtype=torch.long),
            }
        return {
            "input_ids": torch.tensor(ids, dtype=torch.long),
            "attention_mask": torch.tensor(mask, dtype=torch.long),
            "labels": torch.tensor(target, dtype=torch.long),
        }