# Source: ShaNet / collect.py (uploaded by umm-dev)
# Commit 336661a: "Upload 8 files (#1)"
#!/usr/bin/env python3
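"""
Download LMSYS-Chat-1M and convert it into plain text for LLM completion models.

Each message is written in the format:
<|im_start|>role
message<|endoftext|>
<|im_stop|>
where leading/trailing whitespace is stripped from the message. Every
conversation (including the last one) is followed by six blank lines, which
separate it from the next conversation.
"""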
from datasets import load_dataset
import sys
def main(output_path="lmsys_chat_1m.txt", split="train"):
    """Download LMSYS-Chat-1M and write it to ``output_path`` as plain text.

    Every message is rendered as::

        <|im_start|>{role}
        {content}<|endoftext|>
        <|im_stop|>

    with surrounding whitespace stripped from ``content``. After each
    conversation (including the last one) six extra newlines are written, so
    consecutive conversations are separated by six blank lines. Progress is
    reported on stderr every 10,000 conversations; a final confirmation is
    printed to stdout.

    Args:
        output_path: Destination text file; truncated/overwritten if it exists.
        split: Dataset split name passed to ``load_dataset``.
    """
    ds = load_dataset("lmsys/lmsys-chat-1m", split=split)
    # newline="\n" disables text-mode newline translation, so the file uses LF
    # line endings on every OS. Without it, Windows would write "\r\n" and
    # silently change the exact training-text format (including newlines
    # inside message content).
    with open(output_path, "w", encoding="utf-8", newline="\n") as out:
        for i, sample in enumerate(ds):
            conv = sample["conversation"]  # list of message dicts with "role"/"content"
            for msg in conv:
                role = msg["role"]
                content = msg["content"].strip()
                out.write(f"<|im_start|>{role}\n{content}<|endoftext|>\n<|im_stop|>\n")
            out.write("\n" * 6)  # six blank lines after every conversation
            if (i + 1) % 10000 == 0:
                print(f"Processed {i + 1} conversations", file=sys.stderr)
    print(f"✔ Saved plain-text to: {output_path}")
if __name__ == "__main__":
    # Command-line entry point: parse the options, then run the conversion.
    import argparse

    parser = argparse.ArgumentParser(
        description="Convert LMSYS-Chat-1M to LLM-friendly text format"
    )
    parser.add_argument(
        "--output", "-o",
        default="lmsys_chat_1m.txt",
        help="Output file path",
    )
    parser.add_argument(
        "--split", "-s",
        default="train",
        help="Dataset split (e.g. 'train')",
    )
    cli = parser.parse_args()
    main(output_path=cli.output, split=cli.split)