firstAI / scripts /convert_ultrachat_to_text.py
ndc8
Add scripts for converting and generating UltraChat-style SFT dataset
7ecd130
#!/usr/bin/env python3
r"""
Convert UltraChat-style records (prompt, messages[], prompt_id) into a single
"text" field using the chat formatting expected by training/train_gemma_unsloth.py:

    <start_of_turn>user\n...<end_of_turn>\n<start_of_turn>model\n...<end_of_turn>\n(repeat)

Here ``\n`` stands for a newline character. Records whose ``messages`` field is
missing, empty, or malformed are skipped.

Usage:
    python scripts/convert_ultrachat_to_text.py \
        --in sample_data/train_sft.jsonl \
        --out sample_data/train_sft_text.jsonl
"""
from __future__ import annotations
import argparse
import json
import os
from typing import List, TypedDict, Any, Dict, cast
# Source-dataset role name -> role name used in the Gemma chat format.
# The Gemma format calls the assistant side "model". Roles not listed here
# are rendered as "user" by to_chat_text.
ROLE_MAP = {
    "user": "user",
    "assistant": "model",
    # Records that already use Gemma role names: keep model turns as model
    # turns instead of letting the "user" fallback flip the speaker.
    "model": "model",
}
# One chat message taken from a record's "messages" list: the message text
# ("content") and the speaker name ("role", e.g. "user" or "assistant").
# Both keys are required.
Msg = TypedDict("Msg", {"content": str, "role": str})
def to_chat_text(messages: List[Msg]) -> str:
    """Render *messages* as Gemma-style turns, one ``<start_of_turn>`` block each.

    Each role is looked up in ``ROLE_MAP``; a missing or unmapped role becomes
    ``"user"``. Trailing whitespace is trimmed from each message's content,
    and a missing or ``None`` content counts as empty. Turns are joined with
    newlines, and the returned string always ends with a newline.
    """

    def render(msg: Msg) -> str:
        speaker = ROLE_MAP.get(msg.get("role", "user"), "user")
        body = (msg.get("content", "") or "").rstrip()
        return f"<start_of_turn>{speaker}\n{body}<end_of_turn>"

    return "\n".join(map(render, messages)) + "\n"
def _extract_messages(obj: Any) -> List[Msg]:
    """Return the validated ``messages`` of one parsed record, or ``[]`` to skip it.

    A record is skipped when it is not a JSON object, when its ``messages``
    value is missing, empty, or not a list, or when any entry is not an object
    whose ``content`` and ``role`` are both strings. Any other keys on an
    entry are dropped.
    """
    if not isinstance(obj, dict):
        return []
    raw: Any = obj.get("messages")
    if not isinstance(raw, list):
        return []
    messages: List[Msg] = []
    for item in raw:
        if not isinstance(item, dict):
            return []
        content = item.get("content")
        role = item.get("role")
        if not isinstance(content, str) or not isinstance(role, str):
            return []
        messages.append({"content": content, "role": role})
    return messages


def convert(in_path: str, out_path: str) -> int:
    """Convert a JSONL file of UltraChat-style records into ``{"text": ...}`` lines.

    Blank input lines are ignored. Every other line is parsed as JSON. Records
    that pass ``_extract_messages`` are rendered with ``to_chat_text`` and
    written to *out_path*, one JSON object per line. All other records are
    skipped; they show up only in the final "Converted n_out/n_in" count.

    Args:
        in_path: Path of the input JSONL file (UTF-8).
        out_path: Path of the output JSONL file (UTF-8). Its parent directory
            is created if needed.

    Returns:
        0, which ``main`` uses as the process exit status.

    Raises:
        ValueError: If *in_path* and *out_path* name the same existing file.
            Opening the output for writing would truncate the input before it
            is read.
        json.JSONDecodeError: If a non-blank input line is not valid JSON.
        OSError: If either file cannot be opened.
    """
    if (
        os.path.exists(in_path)
        and os.path.exists(out_path)
        and os.path.samefile(in_path, out_path)
    ):
        raise ValueError(f"input and output must be different files: {in_path!r}")
    out_dir = os.path.dirname(out_path)
    # A bare filename has dirname "", and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    n_in = 0
    n_out = 0
    with open(in_path, "r", encoding="utf-8") as fin, open(out_path, "w", encoding="utf-8") as fout:
        for line in fin:
            if not line.strip():
                continue
            n_in += 1
            messages = _extract_messages(json.loads(line))
            if not messages:
                continue
            text = to_chat_text(messages)
            fout.write(json.dumps({"text": text}, ensure_ascii=False) + "\n")
            n_out += 1
    print(f"Converted {n_out}/{n_in} records -> {out_path}")
    return 0
def main() -> int:
    """Parse the required ``--in`` / ``--out`` options and run :func:`convert`.

    Returns the value of ``convert``, which becomes the process exit status.
    """
    parser = argparse.ArgumentParser()
    for flag, dest in (("--in", "in_path"), ("--out", "out_path")):
        parser.add_argument(flag, dest=dest, required=True)
    ns = parser.parse_args()
    return convert(ns.in_path, ns.out_path)
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_code = main()
    raise SystemExit(exit_code)