from abc import ABC, abstractmethod
from typing import Dict, Optional

from datasets import load_dataset, DatasetDict

from tasks.utils.evaluation import TextEvaluationRequest


class DataLoader(ABC):
    """Interface for objects that expose a train split and a test split."""

    @abstractmethod
    def get_train_dataset(self):
        """Return the training split."""

    @abstractmethod
    def get_test_dataset(self):
        """Return the test split."""


class TextDataLoader(DataLoader):
    """Load a labelled text dataset and split it into train and test sets.

    String labels are replaced by the integer ids in ``label_mapping``
    before the ``"train"`` split of the loaded dataset is divided with
    ``train_test_split``.
    """

    # Number of rows kept per split when ``light=True``.
    _LIGHT_TRAIN_SIZE = 10
    _LIGHT_TEST_SIZE = 2

    def __init__(
        self,
        request: Optional[TextEvaluationRequest] = None,
        light: bool = False,
    ):
        """Load, relabel and split the dataset described by ``request``.

        Args:
            request: Supplies ``dataset_name``, ``test_size`` and
                ``test_seed``. When omitted, a new default
                ``TextEvaluationRequest()`` is created for this call.
            light: When true, keep only a small shuffled subset of each
                split (at most 10 train rows and 2 test rows) for quick
                testing.

        Raises:
            KeyError: If a row carries a label missing from
                ``label_mapping``.
        """
        # Build the default per call: a default created in the signature is
        # evaluated once at definition time and shared by every instance.
        if request is None:
            request = TextEvaluationRequest()

        self.label_mapping = {
            "0_not_relevant": 0,
            "1_not_happening": 1,
            "2_not_human": 2,
            "3_not_bad": 3,
            "4_solutions_harmful_unnecessary": 4,
            "5_science_unreliable": 5,
            "6_proponents_biased": 6,
            "7_fossil_fuels_needed": 7,
        }

        # Bind the mapping locally so the mapped function does not need to
        # capture ``self``.
        label_mapping = self.label_mapping

        # Load the dataset, and convert string labels to integers
        dataset = load_dataset(request.dataset_name)
        dataset = dataset.map(lambda x: {"label": label_mapping[x["label"]]})
        self.dataset = dataset["train"].train_test_split(
            test_size=request.test_size,
            seed=request.test_seed,
        )

        # Create a smaller version of the dataset for quick testing
        if light:
            train_split = self.dataset["train"]
            test_split = self.dataset["test"]
            # Clamp to the split length so a tiny dataset does not make
            # ``select`` fail on out-of-range indices.
            n_train = min(self._LIGHT_TRAIN_SIZE, len(train_split))
            n_test = min(self._LIGHT_TEST_SIZE, len(test_split))
            self.dataset = DatasetDict({
                "train": train_split.shuffle(seed=42).select(range(n_train)),
                "test": test_split.shuffle(seed=42).select(range(n_test)),
            })

    def get_train_dataset(self):
        """Return the ``"train"`` split produced in ``__init__``."""
        return self.dataset["train"]

    def get_test_dataset(self):
        """Return the ``"test"`` split produced in ``__init__``."""
        return self.dataset["test"]

    def get_label_to_id_mapping(self) -> Dict[str, int]:
        """Return the mapping from label name to integer id."""
        return self.label_mapping

    def get_id_to_label_mapping(self) -> Dict[int, str]:
        """Return a new dict mapping integer id back to label name."""
        return {v: k for k, v in self.label_mapping.items()}