darabos committed on
Commit 3840cdb · unverified
Parents: 4678319, 01ce750

Merge pull request #233 from biggraph/darabos-unsloth

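This PR adds an Unsloth fine-tuning example for LynxKite: a demo workspace plus a set of boxes for loading a base model, attaching LoRA adapters, preparing a chat dataset, running supervised fine-tuning with TRL, and saving the result as LoRA weights, a merged float16/int4 model, or a GGUF file.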
examples/Unsloth/Demo.lynxkite.json ADDED
The diff for this file is too large to render.
 
examples/Unsloth/boxes.py ADDED
@@ -0,0 +1,206 @@
+ import enum
+ from lynxkite_core import ops
+ from lynxkite_graph_analytics.core import Bundle, TableName, ColumnNameByTableName
+ import unsloth
+ import trl
+ from datasets import load_dataset, Dataset
+ import unsloth.chat_templates
+ from transformers.training_args import OptimizerNames
+ from transformers.trainer_utils import SchedulerType
+
+ op = ops.op_registration("LynxKite Graph Analytics", "Unsloth")
+
+
+ @op("Load base model", slow=True, cache=False)
+ def load_base_model(
+     *,
+     model_name: str,
+     max_seq_length: int = 2048,
+     load_in_4bit: bool = False,
+     load_in_8bit: bool = False,
+     full_finetuning: bool = False,
+ ):
+     model, tokenizer = unsloth.FastModel.from_pretrained(
+         model_name=model_name,
+         max_seq_length=max_seq_length,
+         load_in_4bit=load_in_4bit,
+         load_in_8bit=load_in_8bit,
+         full_finetuning=full_finetuning,
+     )
+     return Bundle(other={"model": model, "tokenizer": tokenizer})
+
+
+ @op("Configure LoRA", slow=True, cache=False)
+ def configure_lora(bundle: Bundle, *, r=128, lora_dropout=0, random_state=1, rank_stabilized=False):
+     bundle = bundle.copy()
+     model = bundle.other["model"]
+     bundle.other["model"] = unsloth.FastModel.get_peft_model(
+         model,
+         r=r,
+         lora_dropout=lora_dropout,
+         random_state=random_state,
+         use_rslora=rank_stabilized,
+         target_modules=[
+             "q_proj",
+             "k_proj",
+             "v_proj",
+             "o_proj",
+             "gate_proj",
+             "up_proj",
+             "down_proj",
+         ],
+         lora_alpha=128,
+         bias="none",
+         use_gradient_checkpointing="unsloth",
+         loftq_config=None,
+     )
+     return bundle
+
+
+ @op("Load HF dataset", slow=True, cache=False)
+ def load_hf_dataset(*, name: str, split="train[:10000]") -> Bundle:
+     return Bundle(dfs={"dataset": load_dataset(name, split=split).to_pandas()})
+
+
+ @op("Convert to ChatML", slow=True, cache=False)
+ def convert_to_chatml(
+     bundle: Bundle,
+     *,
+     table_name: TableName,
+     system_column_name: ColumnNameByTableName,
+     user_column_name: ColumnNameByTableName,
+     assistant_column_name: ColumnNameByTableName,
+     save_as: str = "conversations",
+ ):
+     bundle = bundle.copy()
+     ds = bundle.dfs[table_name]
+     bundle.dfs[table_name][save_as] = ds.apply(
+         lambda e: [
+             {"role": "system", "content": e[system_column_name]},
+             {"role": "user", "content": e[user_column_name]},
+             {"role": "assistant", "content": e[assistant_column_name]},
+         ],
+         axis=1,
+     )
+     return bundle
+
+
+ @op("Apply chat template", slow=True, cache=False)
+ def apply_chat_template(
+     bundle: Bundle,
+     *,
+     table_name: TableName,
+     conversations_field: ColumnNameByTableName,
+     save_as="text",
+ ):
+     bundle = bundle.copy()
+     tokenizer = bundle.other["tokenizer"]
+     bundle.dfs[table_name][save_as] = bundle.dfs[table_name][conversations_field].map(
+         lambda e: tokenizer.apply_chat_template(
+             e, tokenize=False, add_generation_prompt=False
+         ).removeprefix("<bos>"),
+     )
+     return bundle
+
+
+ @op("Train LLM", slow=True, cache=False)
+ def train_llm(
+     bundle: Bundle,
+     *,
+     table_name: TableName,
+     dataset_text_field: ColumnNameByTableName,
+     train_on_responses_only=True,
+     per_device_train_batch_size=8,
+     gradient_accumulation_steps=1,
+     warmup_steps=5,
+     num_train_epochs: int | None = 1,
+     max_steps: int | None = None,
+     learning_rate=5e-5,
+     logging_steps=1,
+     optim=OptimizerNames.ADAMW_8BIT,
+     weight_decay=0.01,
+     lr_scheduler_type=SchedulerType.LINEAR,
+     seed=1,
+ ):
+     model = bundle.other["model"]
+     tokenizer = bundle.other["tokenizer"]
+     dataset = Dataset.from_pandas(bundle.dfs[table_name])
+     trainer = trl.SFTTrainer(
+         model=model,
+         tokenizer=tokenizer,
+         train_dataset=dataset,
+         eval_dataset=None,
+         args=trl.SFTConfig(
+             dataset_text_field=dataset_text_field,
+             per_device_train_batch_size=per_device_train_batch_size,
+             gradient_accumulation_steps=gradient_accumulation_steps,
+             warmup_steps=warmup_steps,
+             num_train_epochs=num_train_epochs or -1,
+             max_steps=max_steps or -1,
+             learning_rate=learning_rate,
+             logging_steps=logging_steps,
+             optim=optim,
+             weight_decay=weight_decay,
+             lr_scheduler_type=lr_scheduler_type,
+             seed=seed,
+             output_dir="outputs",
+             report_to="none",
+         ),
+     )
+     if train_on_responses_only:
+         trainer = unsloth.chat_templates.train_on_responses_only(
+             trainer,
+             instruction_part="<start_of_turn>user\n",
+             response_part="<start_of_turn>model\n",
+         )
+     trainer_stats = trainer.train()
+     bundle = bundle.copy()
+     bundle.other["trainer_stats"] = trainer_stats
+     return bundle
+
+
+ @op("Save model (LoRA only)", outputs=[], slow=True, cache=False)
+ def save_model_lora(bundle: Bundle, *, file_name: str):
+     model = bundle.other["model"]
+     tokenizer = bundle.other["tokenizer"]
+     model.save_pretrained(file_name)
+     tokenizer.save_pretrained(file_name)
+
+
+ @op("Save model (float16)", outputs=[], slow=True, cache=False)
+ def save_model_float16(bundle: Bundle, *, file_name: str):
+     model = bundle.other["model"]
+     tokenizer = bundle.other["tokenizer"]
+     model.save_pretrained_merged(file_name, tokenizer, save_method="merged_16bit")
+
+
+ @op("Save model (int4)", outputs=[], slow=True, cache=False)
+ def save_model_int4(bundle: Bundle, *, file_name: str):
+     model = bundle.other["model"]
+     tokenizer = bundle.other["tokenizer"]
+     model.save_pretrained_merged(file_name, tokenizer, save_method="merged_4bit")
+
+
+ class QuantizationType(enum.StrEnum):
+     Q8_0 = "Q8_0"
+     BF16 = "BF16"
+     F16 = "F16"
+
+
+ @op("Save model (GGUF)", outputs=[], slow=True, cache=False)
+ def save_model_gguf(
+     bundle: Bundle, *, file_name: str, quantization: QuantizationType = QuantizationType.Q8_0
+ ):
+     model = bundle.other["model"]
+     tokenizer = bundle.other["tokenizer"]
+     model.save_pretrained_gguf(
+         file_name,
+         tokenizer,
+         quantization_type=quantization.value,
+     )
+
+
+ @op("Chat with model", view="service")
+ def chat_with_model(bundle: Bundle):
+     # TODO: Implement this.
+     pass
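
For readers exploring the example outside the LynxKite UI, here is a minimal sketch of how these boxes chain together. It assumes the `@op`-decorated functions remain directly callable as plain Python functions and that `Bundle.dfs` is a plain dict; the model name, dataset name, and column names below are placeholders, not part of this PR.

```python
# Hypothetical wiring of the boxes above; in LynxKite they would be connected
# on the canvas instead. All names here are placeholders.
bundle = load_base_model(model_name="unsloth/gemma-3-4b-it", load_in_4bit=True)
bundle = configure_lora(bundle)

data = load_hf_dataset(name="your-org/your-chat-dataset", split="train[:1000]")
bundle.dfs.update(data.dfs)  # merge the dataset table into the model bundle

bundle = convert_to_chatml(
    bundle,
    table_name="dataset",
    system_column_name="system",  # placeholder column names
    user_column_name="question",
    assistant_column_name="answer",
)
bundle = apply_chat_template(bundle, table_name="dataset", conversations_field="conversations")
bundle = train_llm(bundle, table_name="dataset", dataset_text_field="text", max_steps=30)
save_model_lora(bundle, file_name="outputs/lora")
```

Note that `train_llm`'s `train_on_responses_only` option hard-codes `<start_of_turn>` instruction/response markers, so the default only fits Gemma-style chat templates.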
examples/Unsloth/requirements.txt ADDED
@@ -0,0 +1,2 @@
+ datasets
+ unsloth
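
Note that boxes.py also imports `trl` and `transformers` directly; these are not pinned here and are presumably expected to arrive as transitive dependencies of `unsloth`.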