# Provenance: scraped from the Hugging Face file viewer
# (commit 416ebf1, "Remove some unused code/imports.", 1.54 kB).
import functools
import time
from datasets import load_dataset
from src.envs import CODE_PROBLEMS_REPO, RESULTS_REPO, SUBMISSIONS_REPO, TOKEN
from src.logger import get_logger
logger = get_logger(__name__)
class F1Data:
    """Loads the code-problems dataset and exposes its problem ids.

    Only the code-problems dataset is fetched at construction time; the
    submissions/results dataset names are stored for later use by callers.
    """

    def __init__(
        self,
        cp_ds_name: str,
        sub_ds_name: str,
        res_ds_name: str,
        split: str = "hard",
    ):
        """
        Args:
            cp_ds_name: HF dataset repo holding the code problems.
            sub_ds_name: HF dataset repo for submissions (stored only).
            res_ds_name: HF dataset repo for results (stored only).
            split: dataset split to load (defaults to "hard").
        """
        self.cp_dataset_name = cp_ds_name
        self.submissions_dataset_name = sub_ds_name
        self.results_dataset_name = res_ds_name
        self.split = split
        # id -> code_problem text; populated by _initialize() below.
        self.code_problems: dict[str, str] | None = None
        self._initialize()

    def _initialize(self) -> None:
        """Load the code-problems dataset and build the id -> problem map."""
        # Fix: never log the TOKEN value itself — it is a secret credential.
        # Log only whether it is configured.
        logger.info("Initialize F1Data (token %s)", "set" if TOKEN else "not set")
        start_time = time.monotonic()
        cp_ds = load_dataset(self.cp_dataset_name, split=self.split, token=TOKEN)
        logger.info(
            "Loaded code-problems dataset from %s in %f sec",
            self.cp_dataset_name,
            time.monotonic() - start_time,
        )
        self.code_problems = {r["id"]: r["code_problem"] for r in cp_ds}
        # Fix: original mixed a %d placeholder into an f-string, which logged
        # the literal text "Loaded %d code problems {count}". Use lazy %-args.
        logger.info("Loaded %d code problems", len(self.code_problems))

    @functools.cached_property
    def code_problem_ids(self) -> set[str]:
        """Set of all problem ids (cached after first access)."""
        return set(self.code_problems.keys())
if __name__ == "__main__":
    # Smoke-check: load the "hard" split and report how many problems it has.
    chosen_split = "hard"
    data = F1Data(
        cp_ds_name=CODE_PROBLEMS_REPO,
        sub_ds_name=SUBMISSIONS_REPO,
        res_ds_name=RESULTS_REPO,
        split=chosen_split,
    )
    problem_count = len(data.code_problem_ids)
    print(f"Found {problem_count} code problems in {chosen_split} split of {data.cp_dataset_name}")