r"""
Convert UltraChat-style records (prompt, messages[], prompt_id) into a single
"text" field using the chat formatting expected by training/train_gemma_unsloth.py:

    <start_of_turn>user\n...<end_of_turn>\n<start_of_turn>model\n...<end_of_turn>\n(repeat)

Usage:
    python scripts/convert_ultrachat_to_text.py \
        --in sample_data/train_sft.jsonl \
        --out sample_data/train_sft_text.jsonl
"""
|
from __future__ import annotations |
|
|
|
import argparse |
|
import json |
|
import os |
|
from typing import List, TypedDict, Any, Dict, cast |
|
|
|
# Translates the role names found in input records into the turn labels
# written after "<start_of_turn>" in the generated chat text.
ROLE_MAP: Dict[str, str] = {"user": "user", "assistant": "model"}
|
|
|
|
|
class Msg(TypedDict):
    """A single chat turn as stored in a record's ``messages`` list."""

    # Text of the turn.
    content: str
    # Speaker of the turn; looked up in ROLE_MAP when the turn is rendered.
    role: str
|
|
|
|
|
def to_chat_text(messages: List[Msg]) -> str:
    """Render a list of chat turns as one training string.

    Each turn becomes ``<start_of_turn>{role}\\n{content}<end_of_turn>``.
    Turns are separated by a newline and the result always ends with a
    newline, so an empty list yields ``"\\n"``.

    Args:
        messages: Turns to render, in conversation order.

    Returns:
        The concatenated chat text.
    """

    def render_turn(turn: Msg) -> str:
        # A missing role, or one that ROLE_MAP does not know, is labelled "user".
        speaker = ROLE_MAP.get(turn.get("role", "user"), "user")
        # Missing, None or empty content all collapse to the empty string.
        body = turn.get("content", "") or ""
        # Trailing whitespace is dropped so the end marker follows the text directly.
        return f"<start_of_turn>{speaker}\n{body.rstrip()}<end_of_turn>"

    return "\n".join(map(render_turn, messages)) + "\n"
|
|
|
|
|
def _valid_messages(raw: Any) -> List[Msg]:
    """Return the turns in *raw* as ``Msg`` dicts, or ``[]`` if any part is malformed."""
    if not isinstance(raw, list) or not raw:
        return []
    messages: List[Msg] = []
    for item in cast(List[Any], raw):
        if not isinstance(item, dict):
            return []
        entry = cast(Dict[str, Any], item)
        content = entry.get("content")
        role = entry.get("role")
        # One bad turn invalidates the whole record rather than producing a
        # conversation with a hole in it.
        if not isinstance(content, str) or not isinstance(role, str):
            return []
        messages.append({"content": content, "role": role})
    return messages


def convert(in_path: str, out_path: str) -> int:
    """Convert a JSONL file of chat records into a JSONL file of ``{"text": ...}`` rows.

    Blank lines are ignored. A record is skipped (counted as read but not
    written) when it is not a JSON object, when its ``messages`` value is
    missing, empty or not a list, or when any message is not a dict with
    string ``content`` and ``role``. A summary line is printed at the end.

    Args:
        in_path: Path of the UTF-8 JSONL file to read.
        out_path: Path of the UTF-8 JSONL file to write; its parent directory
            is created when the path has one.

    Returns:
        Always ``0``, used by ``main`` as the process exit status.

    Raises:
        OSError: If either file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # dirname() is "" for a bare file name, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    n_in = 0
    n_out = 0
    with open(in_path, "r", encoding="utf-8") as fin:
        with open(out_path, "w", encoding="utf-8") as fout:
            for line in fin:
                if not line.strip():
                    continue
                n_in += 1
                obj: Any = json.loads(line)
                # Valid JSON that is not an object (list, string, number...)
                # has no "messages" key to read; treat it as a skipped record.
                if not isinstance(obj, dict):
                    continue
                messages = _valid_messages(cast(Dict[str, Any], obj).get("messages"))
                if not messages:
                    continue
                text = to_chat_text(messages)
                fout.write(json.dumps({"text": text}, ensure_ascii=False) + "\n")
                n_out += 1

    print(f"Converted {n_out}/{n_in} records -> {out_path}")
    return 0
|
|
|
|
|
def main() -> int:
    """Parse ``--in`` and ``--out`` from the command line and run the conversion.

    Returns:
        The value returned by ``convert``, intended as the process exit status.
    """
    parser = argparse.ArgumentParser()
    # Both paths are mandatory; "in" is a keyword, hence the explicit dest names.
    for flag, dest in (("--in", "in_path"), ("--out", "out_path")):
        parser.add_argument(flag, dest=dest, required=True)
    options = parser.parse_args()
    return convert(options.in_path, options.out_path)
|
|
|
|
|
# Script entry point: exit with the status reported by main().
if __name__ == "__main__":
    raise SystemExit(main())
|
|