# HF_Agents_Final_Project / utilities / common_questions.py
# Yago Bolivar
# feat: implement functions to load GAIA questions and validation records
# df6ca23
import json, sys
from pathlib import Path
from typing import List, Dict, Any
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def load_gaia_questions(path: Path) -> set[str]:
    """Return the set of question strings found in a GAIA JSON file.

    The file is expected to hold a regular JSON array of records. Each
    record's ``"question"`` value is stripped of surrounding whitespace.
    Records that are not JSON objects, whose question is not a string, or
    whose question is blank are skipped, so the empty string never ends up
    in the returned set.

    Args:
        path: location of the GAIA questions file.

    Returns:
        The set of stripped, non-empty question strings.
    """
    # utf-8-sig transparently drops a leading byte-order mark, which would
    # otherwise make json.load reject the file.
    with path.open("r", encoding="utf-8-sig") as f:
        data: List[Dict[str, Any]] = json.load(f)

    questions: set[str] = set()
    for rec in data:
        q = rec.get("question") if isinstance(rec, dict) else None
        if not isinstance(q, str):
            continue
        text = q.strip()
        # Strip BEFORE testing emptiness: a whitespace-only question must not
        # contribute "" to the set (it would match any blank question later).
        if text:
            questions.add(text)
    return questions
def load_validation_records(path: Path) -> List[Dict[str, Any]]:
    """Return the full dict records from a newline-delimited JSON file.

    Each non-blank line is parsed independently with ``json.loads``. Lines
    that are not valid JSON, or that parse to something other than a JSON
    object, are skipped with a warning on stderr naming the file and line
    number. Skipping non-objects keeps callers' ``rec.get(...)`` from failing
    on stray lists or scalars.

    Args:
        path: location of the validation file (one JSON object per line).

    Returns:
        The parsed records, in file order.
    """
    records: List[Dict[str, Any]] = []
    # utf-8-sig transparently drops a leading byte-order mark, which would
    # otherwise make json.loads reject the first line.
    with path.open("r", encoding="utf-8-sig") as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError as e:
                # Skip malformed line but warn user (stderr keeps stdout clean).
                print(
                    f"Warning: could not parse line {lineno} in {path}: {e}",
                    file=sys.stderr,
                )
                continue
            if not isinstance(obj, dict):
                print(
                    f"Warning: skipping non-object JSON on line {lineno} in {path}",
                    file=sys.stderr,
                )
                continue
            records.append(obj)
    return records
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main(
    gaia_path: str = "question_set/gaia_questions.json",
    validation_path: str = "question_set/validation.json",
    output_path: str = "question_set/common_questions.json",
) -> None:
    """Write the validation records whose question also appears in the GAIA file.

    Args:
        gaia_path: JSON array of GAIA records, read via ``load_gaia_questions``.
        validation_path: newline-delimited JSON validation records, read via
            ``load_validation_records``.
        output_path: destination for the matching records, written as an
            indented JSON array; parent directories are created if needed.

    Exits via ``sys.exit`` with an error message if either input path is not
    an existing regular file.
    """
    gaia_file = Path(gaia_path)
    val_file = Path(validation_path)
    # is_file() rather than exists(): a directory passed by mistake is
    # reported here instead of failing later inside open().
    for required in (gaia_file, val_file):
        if not required.is_file():
            sys.exit(f"Error: {required} not found")

    gaia_questions = load_gaia_questions(gaia_file)
    validation_recs = load_validation_records(val_file)

    # Keep validation records whose stripped question text is in the GAIA set.
    # The key may be spelled "question" or "Question". Non-dict records and
    # non-string questions are skipped instead of crashing on .get()/.strip().
    # Blank questions never match, even if the GAIA set happens to contain "".
    common: List[Dict[str, Any]] = []
    for rec in validation_recs:
        if not isinstance(rec, dict):
            continue
        q = rec.get("question") or rec.get("Question")
        if not isinstance(q, str):
            continue
        text = q.strip()
        if text and text in gaia_questions:
            common.append(rec)

    # Ensure output directory exists
    out_path = Path(output_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Write out as a JSON array (same style as GAIA file)
    with out_path.open("w", encoding="utf-8") as f:
        json.dump(common, f, indent=2, ensure_ascii=False)
    print(f"Extracted {len(common)} record(s) – full validation fields kept –> {out_path}")
if __name__ == "__main__":
    # Allow command-line overrides: python script.py [gaia] [validation] [output]
    # main() takes at most three positional paths; reject extras with a usage
    # message instead of letting main(*argv) raise a bare TypeError.
    if len(sys.argv) > 4:
        sys.exit(f"Usage: {Path(sys.argv[0]).name} [gaia] [validation] [output]")
    main(*sys.argv[1:])