#!/usr/bin/env python3
"""
Convert UltraChat-style records (prompt, messages[], prompt_id) into a single
"text" field using the chat formatting expected by
training/train_gemma_unsloth.py:

    user\n...\nmodel\n...\n(repeat)

Usage:
    python scripts/convert_ultrachat_to_text.py \
        --in sample_data/train_sft.jsonl \
        --out sample_data/train_sft_text.jsonl
"""
from __future__ import annotations

import argparse
import json
import os
from typing import List, TypedDict, Any, Dict, cast

# Maps the role names found in the input records to the role labels written
# into the output text. Roles missing from this table fall back to "user".
ROLE_MAP = {
    "user": "user",
    "assistant": "model",
}


class Msg(TypedDict):
    """A single chat turn: the speaker's role and the text they produced."""

    content: str
    role: str


def to_chat_text(messages: List[Msg]) -> str:
    """Render a list of chat turns as one newline-delimited string.

    Each turn becomes ``"<role>\\n<content>"`` where the role is translated
    through ``ROLE_MAP`` (unknown or missing roles become ``"user"``) and
    trailing whitespace is stripped from the content. Turns are joined with a
    newline and the result always ends with a newline.

    Args:
        messages: Chat turns in conversation order.

    Returns:
        The formatted conversation. An empty list yields ``"\\n"``.
    """
    parts: List[str] = []
    for msg in messages:
        role = ROLE_MAP.get(msg.get("role", "user"), "user")
        # ``or ""`` guards against an explicit ``None`` content value.
        content = (msg.get("content", "") or "").rstrip()
        parts.append(f"{role}\n{content}")
    return "\n".join(parts) + "\n"


def _extract_messages(record: Any) -> List[Msg]:
    """Return the validated chat turns of *record*, or ``[]`` if unusable.

    A record is usable only when it is a JSON object whose ``"messages"``
    value is a non-empty list of objects, each carrying string ``"content"``
    and ``"role"`` values. Any violation rejects the whole record.
    """
    # A line may hold valid JSON that is not an object (list, number, ...);
    # treat it like any other malformed record instead of crashing on .get().
    if not isinstance(record, dict):
        return []

    raw: Any = cast(Dict[str, Any], record).get("messages")
    if not isinstance(raw, list) or not raw:
        return []

    messages: List[Msg] = []
    for item in cast(List[Any], raw):
        if not isinstance(item, dict):
            return []
        content = item.get("content")
        role = item.get("role")
        if not isinstance(content, str) or not isinstance(role, str):
            return []
        messages.append({"content": content, "role": role})
    return messages


def convert(in_path: str, out_path: str) -> int:
    """Convert a JSONL file of chat records into a JSONL file of ``text`` rows.

    Blank input lines are ignored. Every other line is parsed as JSON; records
    that do not contain a well-formed ``"messages"`` list are skipped. Each
    accepted record is written as ``{"text": <formatted conversation>}``. A
    summary of converted/total record counts is printed at the end.

    Args:
        in_path: Path of the input JSONL file.
        out_path: Path of the output JSONL file. Its parent directory is
            created if it does not exist.

    Returns:
        ``0`` on completion (used as the process exit status).

    Raises:
        json.JSONDecodeError: If a non-blank input line is not valid JSON.
        OSError: If the input cannot be read or the output cannot be written.
    """
    # dirname() is "" for a bare filename, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    n_in = 0
    n_out = 0
    with open(in_path, "r", encoding="utf-8") as fin, \
            open(out_path, "w", encoding="utf-8") as fout:
        for line in fin:
            if not line.strip():
                continue
            n_in += 1

            messages = _extract_messages(json.loads(line))
            if not messages:
                continue

            text = to_chat_text(messages)
            fout.write(json.dumps({"text": text}, ensure_ascii=False) + "\n")
            n_out += 1

    print(f"Converted {n_out}/{n_in} records -> {out_path}")
    return 0


def main() -> int:
    """Parse ``--in`` / ``--out`` from the command line and run the conversion.

    Returns:
        The exit status produced by :func:`convert`.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--in", dest="in_path", required=True)
    ap.add_argument("--out", dest="out_path", required=True)
    args = ap.parse_args()
    return convert(args.in_path, args.out_path)


if __name__ == "__main__":
    raise SystemExit(main())