jefalod committed · Commit 3eb42eb · verified · 1 Parent(s): 61f1efa

Create train.py

Files changed (1)
  1. train.py +92 -0
train.py ADDED
@@ -0,0 +1,92 @@
# train.py

import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, TrainingArguments,
    Trainer, BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Load dataset from URL
dataset = load_dataset(
    "json",
    data_files="https://huggingface.co/datasets/bitext/Bitext-customer-support-llm-chatbot-training-dataset/resolve/main/bitext_customer_support.jsonl",
    split="train[:100]"  # limit for fast training in Spaces
)

# Format each record into an instruction/response prompt
def format_example(example):
    return {
        "text": f"### Instruction:\n{example['instruction']}\n\n### Response:\n{example['output']}"
    }

dataset = dataset.map(format_example)

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token  # use EOS as the padding token

# Tokenize prompts; labels are a copy of input_ids for causal LM training
def tokenize(example):
    tokens = tokenizer(
        example["text"],
        padding="max_length",
        truncation=True,
        max_length=512
    )
    tokens["labels"] = tokens["input_ids"].copy()
    return tokens

tokenized_dataset = dataset.map(tokenize, batched=True)

# 4-bit NF4 quantization (QLoRA-style) to fit the model in limited GPU memory
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto"
)

# Enable gradient checkpointing and prepare the quantized model for k-bit training
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

# LoRA adapter applied to the attention query/value projections
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"]
)

model = get_peft_model(model, lora_config)

training_args = TrainingArguments(
    output_dir="trained_model",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    num_train_epochs=1,
    learning_rate=2e-4,
    logging_dir="./logs",
    save_strategy="no",
    bf16=True,
    optim="paged_adamw_8bit",
    report_to="none"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer
)

trainer.train()

# Save the LoRA adapter weights and the tokenizer
model.save_pretrained("trained_model")
tokenizer.save_pretrained("trained_model")
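
Note: train.py saves only the LoRA adapter weights and the tokenizer to trained_model. A minimal sketch of how that adapter could be loaded for generation, assuming the paths and prompt format used above (the script name and example prompt are illustrative, not part of this commit):

# inference_sketch.py (not part of the commit)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Load the tokenizer saved by train.py and the base model, then attach the adapter
tokenizer = AutoTokenizer.from_pretrained("trained_model")
base = AutoModelForCausalLM.from_pretrained(
    base_name, torch_dtype=torch.float16, device_map="auto"
)
model = PeftModel.from_pretrained(base, "trained_model")
model.eval()

# Prompt follows the same Instruction/Response template used during training
prompt = "### Instruction:\nHow do I reset my password?\n\n### Response:\n"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(output[0], skip_special_tokens=True))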