# data_utils.py — local-first dataset loaders + hashing vectorizer
from typing import List
import os, json, zlib
import numpy as np

try:
    from datasets import load_dataset  # optional, used only as a fallback
except Exception:
    load_dataset = None
# -----------------------------
# Hashing vectorizer (unigram + bigram)
# -----------------------------
def hash_vectorize(texts: List[str], n_features: int = 4096, seed: int = 1234) -> np.ndarray:
    """Map texts to L2-normalized bag-of-(uni+bi)gram vectors via the hashing trick."""
    def _bucket(s: str) -> int:
        # Stable, seed-dependent bucket. Python's built-in hash() is salted
        # per process for strings, which would make features non-reproducible
        # (and would silently ignore `seed`), so we hash with crc32 instead.
        return zlib.crc32(f"{seed}:{s}".encode("utf-8")) % n_features

    n = len(texts)
    X = np.zeros((n, n_features), dtype=np.float32)
    for i, t in enumerate(texts):
        if not t:
            continue
        toks = t.lower().split()
        prev = None
        for tok in toks:
            X[i, _bucket(tok)] += 1.0                        # unigram count
            if prev is not None:
                X[i, _bucket(prev + "_" + tok)] += 1.0       # bigram count
            prev = tok
        norm = float(np.linalg.norm(X[i])) + 1e-8            # avoid divide-by-zero
        X[i] /= norm
    return X
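
# Example: a quick sanity check of the vectorizer (the texts below are made up
# for illustration). Rows are L2-normalized, so a dot product between two rows
# is the cosine similarity of their hashed n-gram profiles:
#
#   X = hash_vectorize(["open the jar with a spoon", "open the jar by hand"])
#   X.shape             # -> (2, 4096)
#   float(X[0] @ X[1])  # cosine similarity, in [0, 1] since counts are non-negative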
# -----------------------------
# Utilities for local JSONL
# -----------------------------
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")

def _read_jsonl(path: str):
    out = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                out.append(json.loads(line))
    return out

def _has_local(*names: str) -> bool:
    return all(os.path.exists(os.path.join(DATA_DIR, n)) for n in names)
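
# Expected local file layout. The loaders below read only the keys shown here;
# any extra keys in a row are ignored:
#
#   data/piqa_train.jsonl, data/piqa_valid.jsonl — one JSON object per line:
#     {"goal": "...", "sol1": "...", "sol2": "...", "label": 0}              # label: 0 or 1
#   data/hellaswag_train.jsonl, data/hellaswag_valid.jsonl:
#     {"ctx": "...", "ctx_a": "...", "endings": ["...", "..."], "label": 0}  # label: 0..3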
# -----------------------------
# PIQA loader (pair-expanded)
# -----------------------------
def load_piqa(subset: int = 800, seed: int = 42):
    """
    Returns:
        Xtr_txt, ytr, Xva_txt, yva
    For each original PIQA example, we emit TWO rows:
        [goal + sol1] with label 1 if sol1 is correct, else 0
        [goal + sol2] with label 1 if sol2 is correct, else 0
    """
    rng = np.random.RandomState(seed)
    # Prefer local JSONL files; fall back to the Hugging Face hub only if needed.
    tr_name, va_name = "piqa_train.jsonl", "piqa_valid.jsonl"
    if _has_local(tr_name, va_name):
        tr = _read_jsonl(os.path.join(DATA_DIR, tr_name))
        va = _read_jsonl(os.path.join(DATA_DIR, va_name))
    else:
        if load_dataset is None:
            raise RuntimeError("PIQA local files not found and 'datasets' not installed.")
        ds = load_dataset("piqa")
        tr, va = list(ds["train"]), list(ds["validation"])
    # Subsample without replacement; validation gets subset // 4, at least 200 rows.
    idx_tr = rng.choice(len(tr), size=min(subset, len(tr)), replace=False)
    idx_va = rng.choice(len(va), size=min(max(subset // 4, 200), len(va)), replace=False)

    def pack(rows, idxs):
        X_text, y = [], []
        for k in idxs:
            p = rows[k]
            stem = (p.get("goal") or "").strip()
            sol1 = (p.get("sol1") or "").strip()
            sol2 = (p.get("sol2") or "").strip()
            label = int(p.get("label", 0))  # 0 -> sol1 is correct, 1 -> sol2 is correct
            X_text.append(f"{stem} {sol1}")
            y.append(1 if label == 0 else 0)
            X_text.append(f"{stem} {sol2}")
            y.append(1 if label == 1 else 0)
        return X_text, np.array(y, dtype=np.int64)

    Xtr_txt, ytr = pack(tr, idx_tr)
    Xva_txt, yva = pack(va, idx_va)
    return Xtr_txt, ytr, Xva_txt, yva
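
# Usage sketch (assumes either data/piqa_*.jsonl exists or 'datasets' is
# installed). Pair expansion doubles the row count, so subset=800 yields up to
# 1600 training rows, exactly half of them labeled 1:
#
#   Xtr_txt, ytr, Xva_txt, yva = load_piqa(subset=800)
#   Xtr = hash_vectorize(Xtr_txt)   # (len(Xtr_txt), 4096) float32 matrix
#   assert len(Xtr_txt) == len(ytr) and set(ytr.tolist()) <= {0, 1}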
# -----------------------------
# HellaSwag loader (4-way expanded)
# -----------------------------
def load_hellaswag(subset: int = 800, seed: int = 42):
    """
    Returns:
        Xtr_txt, ytr, Xva_txt, yva
    For each example, we emit FOUR rows:
        [context + ending_i] with label 1 if i is the correct ending, else 0
    """
    rng = np.random.RandomState(seed)
    tr_name, va_name = "hellaswag_train.jsonl", "hellaswag_valid.jsonl"
    if _has_local(tr_name, va_name):
        tr = _read_jsonl(os.path.join(DATA_DIR, tr_name))
        va = _read_jsonl(os.path.join(DATA_DIR, va_name))
    else:
        if load_dataset is None:
            raise RuntimeError("HellaSwag local files not found and 'datasets' not installed.")
        ds = load_dataset("hellaswag")
        tr, va = list(ds["train"]), list(ds["validation"])
    idx_tr = rng.choice(len(tr), size=min(subset, len(tr)), replace=False)
    idx_va = rng.choice(len(va), size=min(max(subset // 4, 200), len(va)), replace=False)

    def pack(rows, idxs):
        X_text, y = [], []
        for k in idxs:
            p = rows[k]
            ctx = f"{(p.get('ctx') or '')} {(p.get('ctx_a') or '')}".strip()
            endings = p.get("endings") or []
            label = int(p.get("label", 0))  # the hub stores this as a string; int() accepts both
            for i, e in enumerate(endings):
                X_text.append(f"{ctx} {e}".strip())
                y.append(1 if i == label else 0)
        return X_text, np.array(y, dtype=np.int64)

    Xtr_txt, ytr = pack(tr, idx_tr)
    Xva_txt, yva = pack(va, idx_va)
    return Xtr_txt, ytr, Xva_txt, yva
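
# Minimal smoke test when run as a script. This is a sketch: it assumes the
# local JSONL files (or the optional 'datasets' package) may be missing, so
# loader failures are reported rather than raised, and subsets are kept small.
if __name__ == "__main__":
    demo = hash_vectorize(["a quick smoke test", "of the hashing vectorizer"])
    print("hash_vectorize:", demo.shape, "row norms:", np.linalg.norm(demo, axis=1))
    for name, loader in [("piqa", load_piqa), ("hellaswag", load_hellaswag)]:
        try:
            Xtr_txt, ytr, Xva_txt, yva = loader(subset=50)
            print(f"{name}: {len(Xtr_txt)} train rows, {len(Xva_txt)} valid rows, "
                  f"positive rate {ytr.mean():.3f}")
        except Exception as e:
            print(f"{name}: skipped ({e})")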