File size: 2,775 Bytes
7ecd130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#!/usr/bin/env python3
"""
Convert UltraChat-style records (prompt, messages[], prompt_id) into a single
"text" field using the chat formatting expected by training/train_gemma_unsloth.py:

<start_of_turn>user\n...<end_of_turn>\n<start_of_turn>model\n...<end_of_turn>\n(repeat)

Usage:
  python scripts/convert_ultrachat_to_text.py \
      --in sample_data/train_sft.jsonl \
      --out sample_data/train_sft_text.jsonl
"""
from __future__ import annotations

import argparse
import json
import os
from typing import List, TypedDict, Any, Dict, cast

# Translates the role names found in input records into the role names written
# after "<start_of_turn>". Roles absent from this table fall back to "user"
# in to_chat_text.
ROLE_MAP: Dict[str, str] = {
    "user": "user",
    "assistant": "model",
}


class Msg(TypedDict):
    """A single chat message: its text plus the role of the speaker."""

    # Message text.
    content: str
    # Speaker role as given in the input record (looked up in ROLE_MAP when rendering).
    role: str


def to_chat_text(messages: List[Msg]) -> str:
    """Render a conversation as start_of_turn/end_of_turn delimited text.

    Every message becomes one turn: the start marker followed by the mapped
    role, a newline, the message content with trailing whitespace removed,
    and the end marker. Turns are separated by a newline and the result
    always ends with a newline.

    Args:
        messages: Conversation messages in the order they should appear.

    Returns:
        The formatted conversation as a single string.
    """

    def render_turn(msg: Msg) -> str:
        # A missing or unmapped role is rendered as a "user" turn.
        speaker = ROLE_MAP.get(msg.get("role", "user"), "user")
        # Treat a missing or None content as an empty string.
        body = msg.get("content", "") or ""
        return f"<start_of_turn>{speaker}\n{body.rstrip()}<end_of_turn>"

    return "\n".join(map(render_turn, messages)) + "\n"


def _extract_messages(obj: Any) -> List[Msg]:
    """Return the validated messages of one decoded record, or [] if unusable.

    A record is usable only when it is a JSON object whose "messages" value is
    a non-empty list of objects that each carry string "content" and "role".
    """
    # A line may decode to any JSON value (list, number, string, ...); only
    # objects can carry a "messages" key.
    if not isinstance(obj, dict):
        return []
    raw = obj.get("messages")
    if not isinstance(raw, list):
        return []
    messages: List[Msg] = []
    for item in raw:
        if not isinstance(item, dict):
            return []
        content = item.get("content")
        role = item.get("role")
        # One malformed message invalidates the whole record.
        if not isinstance(content, str) or not isinstance(role, str):
            return []
        messages.append({"content": content, "role": role})
    return messages


def convert(in_path: str, out_path: str) -> int:
    """Convert a JSONL file of chat records into JSONL rows with a "text" field.

    Each non-blank input line is decoded as JSON. Records with a valid,
    non-empty "messages" list are rendered with ``to_chat_text`` and written
    as ``{"text": ...}``; all other records are skipped. A summary line with
    the written/read counts is printed at the end.

    Args:
        in_path: Path of the input JSONL file.
        out_path: Path of the output JSONL file. Its parent directory is
            created when it does not exist; a bare filename is written to the
            current working directory.

    Returns:
        Always 0.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        OSError: If the input cannot be read or the output cannot be written.
    """
    # os.path.dirname() yields "" for a bare filename and os.makedirs("")
    # raises FileNotFoundError, so only create a directory when one is named.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    n_in = 0
    n_out = 0
    with open(in_path, "r", encoding="utf-8") as fin, open(
        out_path, "w", encoding="utf-8"
    ) as fout:
        for line in fin:
            if not line.strip():
                continue
            n_in += 1
            messages = _extract_messages(json.loads(line))
            if not messages:
                continue
            text = to_chat_text(messages)
            fout.write(json.dumps({"text": text}, ensure_ascii=False) + "\n")
            n_out += 1
    print(f"Converted {n_out}/{n_in} records -> {out_path}")
    return 0


def main() -> int:
    """Parse the command line and run the conversion.

    Both ``--in`` and ``--out`` are required options.

    Returns:
        The value returned by ``convert``.
    """
    parser = argparse.ArgumentParser()
    # Explicit dest names are needed because "in" is a Python keyword and
    # could not be read back as an attribute of the parsed namespace.
    for flag, dest in (("--in", "in_path"), ("--out", "out_path")):
        parser.add_argument(flag, dest=dest, required=True)
    options = parser.parse_args()
    return convert(options.in_path, options.out_path)


# When executed as a script, run the CLI and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())