HemanM commited on
Commit
718e236
·
verified ·
1 Parent(s): f8b9865

Create prep_datasets.py

Browse files
Files changed (1) hide show
  1. prep_datasets.py +58 -0
prep_datasets.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # prep_datasets.py
2
+ # One-time exporter: saves PIQA + HellaSwag to clean JSONL for offline use.
3
+
4
+ from datasets import load_dataset
5
+ import json, os
6
+
7
+ OUT_DIR = "data"
8
+ os.makedirs(OUT_DIR, exist_ok=True)
9
+
10
+ def write_jsonl(path, rows):
11
+ with open(path, "w", encoding="utf-8") as f:
12
+ for r in rows:
13
+ f.write(json.dumps(r, ensure_ascii=False) + "\n")
14
+
15
+ # --- PIQA ---
16
+ print("Downloading PIQA…")
17
+ piqa = load_dataset("piqa")
18
+ def piqa_clean(split):
19
+ out = []
20
+ for ex in split:
21
+ out.append({
22
+ "goal": ex.get("goal") or "",
23
+ "sol1": ex.get("sol1") or "",
24
+ "sol2": ex.get("sol2") or "",
25
+ "label": int(ex.get("label", 0))
26
+ })
27
+ return out
28
+
29
+ piqa_train = piqa["train"]
30
+ piqa_valid = piqa["validation"]
31
+
32
+ print("Writing PIQA JSONL…")
33
+ write_jsonl(os.path.join(OUT_DIR, "piqa_train.jsonl"), piqa_clean(piqa_train))
34
+ write_jsonl(os.path.join(OUT_DIR, "piqa_valid.jsonl"), piqa_clean(piqa_valid))
35
+
36
+ # --- HellaSwag ---
37
+ print("Downloading HellaSwag…")
38
+ hs = load_dataset("hellaswag")
39
+ def hs_clean(split):
40
+ out = []
41
+ for ex in split:
42
+ out.append({
43
+ # keep both ctx and ctx_a to be safe (some variants use both)
44
+ "ctx": ex.get("ctx") or "",
45
+ "ctx_a": ex.get("ctx_a") or "",
46
+ "endings": list(ex.get("endings") or []),
47
+ "label": int(ex.get("label", 0))
48
+ })
49
+ return out
50
+
51
+ hs_train = hs["train"]
52
+ hs_valid = hs["validation"]
53
+
54
+ print("Writing HellaSwag JSONL…")
55
+ write_jsonl(os.path.join(OUT_DIR, "hellaswag_train.jsonl"), hs_clean(hs_train))
56
+ write_jsonl(os.path.join(OUT_DIR, "hellaswag_valid.jsonl"), hs_clean(hs_valid))
57
+
58
+ print("Done. Files created in ./data")