# data_utils.py: local-first dataset loaders + hashing vectorizer
from typing import List
import os, json, hashlib
import numpy as np

try:
    from datasets import load_dataset  # optional, used only as a fallback
except Exception:
    load_dataset = None
# -----------------------------
# Hashing vectorizer (unigram + bigram)
# -----------------------------
def _stable_hash(s: str, seed: int) -> int:
    """Deterministic token hash. The built-in hash() is salted per process
    (PYTHONHASHSEED), so it maps tokens to different buckets on every run and
    ignores `seed`; a blake2b digest keyed with the seed is stable across runs."""
    digest = hashlib.blake2b(s.encode("utf-8"), digest_size=8,
                             key=str(seed).encode("utf-8")).digest()
    return int.from_bytes(digest, "little")

def hash_vectorize(texts: List[str], n_features: int = 4096, seed: int = 1234) -> np.ndarray:
    """Map each text to an L2-normalized bag of hashed unigrams and bigrams."""
    n = len(texts)
    X = np.zeros((n, n_features), dtype=np.float32)
    for i, t in enumerate(texts):
        if not t:
            continue
        toks = t.lower().split()
        prev = None
        for tok in toks:
            X[i, _stable_hash(tok, seed) % n_features] += 1.0
            if prev is not None:
                X[i, _stable_hash(prev + "_" + tok, seed) % n_features] += 1.0
            prev = tok
        # Normalize each row; the epsilon guards against all-zero rows.
        norm = float(np.linalg.norm(X[i])) + 1e-8
        X[i] /= norm
    return X
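
# Minimal sanity-check sketch for hash_vectorize (not part of the original
# module); the example texts and 64-dim feature size are illustrative only.
def _demo_hash_vectorize():
    X = hash_vectorize(["the cat sat on the mat", "the dog ran"], n_features=64)
    assert X.shape == (2, 64)
    # Non-empty rows come back L2-normalized.
    assert abs(float(np.linalg.norm(X[0])) - 1.0) < 1e-4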
# -----------------------------
# Utilities for local JSONL
# -----------------------------
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")

def _read_jsonl(path: str):
    out = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                out.append(json.loads(line))
    return out

def _has_local(*names: str) -> bool:
    return all(os.path.exists(os.path.join(DATA_DIR, n)) for n in names)
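
# Expected local layout (inferred from the loaders below; one JSON object per
# line). PIQA rows carry "goal", "sol1", "sol2", "label" in {0, 1}; HellaSwag
# rows carry "ctx" (or "ctx_a"), "endings" (a list), "label" in {0..3}:
#   data/piqa_train.jsonl       data/piqa_valid.jsonl
#   data/hellaswag_train.jsonl  data/hellaswag_valid.jsonl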
# -----------------------------
# PIQA loader (pair-expanded)
# -----------------------------
def load_piqa(subset: int = 800, seed: int = 42):
    """
    Returns:
        Xtr_txt, ytr, Xva_txt, yva
    For each original PIQA example, we emit TWO rows:
        [goal + sol1] with label 1 if sol1 is correct, else 0
        [goal + sol2] with label 1 if sol2 is correct, else 0
    """
    rng = np.random.RandomState(seed)
    # Prefer local files; fall back to the 'datasets' hub only if needed.
    tr_name, va_name = "piqa_train.jsonl", "piqa_valid.jsonl"
    if _has_local(tr_name, va_name):
        tr = _read_jsonl(os.path.join(DATA_DIR, tr_name))
        va = _read_jsonl(os.path.join(DATA_DIR, va_name))
    else:
        if load_dataset is None:
            raise RuntimeError("PIQA local files not found and 'datasets' not installed.")
        ds = load_dataset("piqa")
        tr, va = list(ds["train"]), list(ds["validation"])
    # Subsample without replacement; validation keeps at least 200 examples.
    idx_tr = rng.choice(len(tr), size=min(subset, len(tr)), replace=False)
    idx_va = rng.choice(len(va), size=min(max(subset // 4, 200), len(va)), replace=False)

    def pack(rows, idxs):
        X_text, y = [], []
        for k in idxs:
            p = rows[k]
            stem = (p.get("goal") or "").strip()
            sol1 = (p.get("sol1") or "").strip()
            sol2 = (p.get("sol2") or "").strip()
            label = int(p.get("label", 0))  # 0 -> sol1 correct, 1 -> sol2 correct
            X_text.append(f"{stem} {sol1}")
            y.append(1 if label == 0 else 0)
            X_text.append(f"{stem} {sol2}")
            y.append(1 if label == 1 else 0)
        return X_text, np.array(y, dtype=np.int64)

    Xtr_txt, ytr = pack(tr, idx_tr)
    Xva_txt, yva = pack(va, idx_va)
    return Xtr_txt, ytr, Xva_txt, yva
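
# Hedged usage sketch: turn PIQA into a binary classification matrix. The
# subset size is illustrative; requires data/piqa_*.jsonl or 'datasets'.
def _demo_load_piqa():
    Xtr_txt, ytr, Xva_txt, yva = load_piqa(subset=100)
    Xtr = hash_vectorize(Xtr_txt)  # shape (200, 4096): two rows per example
    assert Xtr.shape[0] == len(ytr) == 200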
# -----------------------------
# HellaSwag loader (4-way expanded)
# -----------------------------
def load_hellaswag(subset: int = 800, seed: int = 42):
    """
    Returns:
        Xtr_txt, ytr, Xva_txt, yva
    For each example, we emit FOUR rows:
        [context + ending_i] with label 1 if i is the correct ending, else 0
    """
    rng = np.random.RandomState(seed)
    tr_name, va_name = "hellaswag_train.jsonl", "hellaswag_valid.jsonl"
    if _has_local(tr_name, va_name):
        tr = _read_jsonl(os.path.join(DATA_DIR, tr_name))
        va = _read_jsonl(os.path.join(DATA_DIR, va_name))
    else:
        if load_dataset is None:
            raise RuntimeError("HellaSwag local files not found and 'datasets' not installed.")
        ds = load_dataset("hellaswag")
        tr, va = list(ds["train"]), list(ds["validation"])
    idx_tr = rng.choice(len(tr), size=min(subset, len(tr)), replace=False)
    idx_va = rng.choice(len(va), size=min(max(subset // 4, 200), len(va)), replace=False)

    def pack(rows, idxs):
        X_text, y = [], []
        for k in idxs:
            p = rows[k]
            # In the HF schema "ctx" already includes "ctx_a", so use "ctx_a"
            # only as a fallback rather than appending it (which duplicates text).
            ctx = (p.get("ctx") or p.get("ctx_a") or "").strip()
            endings = p.get("endings") or []
            label = int(p.get("label", 0))  # the hub stores labels as strings
            for i, e in enumerate(endings):
                X_text.append(f"{ctx} {e}".strip())
                y.append(1 if i == label else 0)
        return X_text, np.array(y, dtype=np.int64)

    Xtr_txt, ytr = pack(tr, idx_tr)
    Xva_txt, yva = pack(va, idx_va)
    return Xtr_txt, ytr, Xva_txt, yva
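
# Hedged end-to-end sketch: score hashed rows against a positive-class
# centroid. The model choice is illustrative only; anything consuming (X, y)
# pairs from the loaders above works the same way.
if __name__ == "__main__":
    Xtr_txt, ytr, Xva_txt, yva = load_hellaswag(subset=200)
    Xtr, Xva = hash_vectorize(Xtr_txt), hash_vectorize(Xva_txt)
    # Mean vector of the correct-ending rows, used as a similarity template.
    centroid = Xtr[ytr == 1].mean(axis=0)
    preds = (Xva @ centroid > np.median(Xtr @ centroid)).astype(np.int64)
    print("row-level accuracy:", float((preds == yva).mean()))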