bigPear committed on
Commit 7975f51 · 1 Parent(s): 6f1c658

Upload 76 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +8 -0
  2. LICENSE +201 -0
  3. README_zh.md +339 -0
  4. assets/trainer_state.jpg +0 -0
  5. assets/wechat.jpg +0 -0
  6. data/.ipynb_checkpoints/comparison_gpt4_data_zh-checkpoint.json +3 -0
  7. data/.ipynb_checkpoints/dataset_info-checkpoint.json +113 -0
  8. data/.ipynb_checkpoints/self_cognition-checkpoint.json +92 -0
  9. data/README.md +53 -0
  10. data/alpaca_data_en_52k.json +3 -0
  11. data/alpaca_data_zh_51k.json +3 -0
  12. data/alpaca_gpt4_data_en.json +3 -0
  13. data/alpaca_gpt4_data_zh.json +3 -0
  14. data/belle_multiturn/belle_multiturn.py +70 -0
  15. data/comparison_gpt4_data_en.json +3 -0
  16. data/comparison_gpt4_data_zh.json +3 -0
  17. data/dataset_info.json +113 -0
  18. data/example_dataset/example_dataset.py +46 -0
  19. data/example_dataset/examples.json +20 -0
  20. data/hh_rlhf_en/hh_rlhf_en.py +97 -0
  21. data/self_cognition.json +92 -0
  22. data/ultra_chat/ultra_chat.py +76 -0
  23. data/wdf.json +0 -0
  24. examples/ads_generation.md +172 -0
  25. examples/alter_self_cognition.md +104 -0
  26. examples/cloudflare-tunnel-to-colab.ipynb +143 -0
  27. examples/evaluate.sh +12 -0
  28. examples/finetune.sh +17 -0
  29. examples/media/ads_generation_1.jpg +0 -0
  30. examples/media/ads_generation_2.jpg +0 -0
  31. examples/media/ads_generation_3.jpg +0 -0
  32. examples/media/ads_generation_4.jpg +0 -0
  33. examples/media/ads_generation_5.jpg +0 -0
  34. examples/media/ads_generation_6.jpg +0 -0
  35. examples/media/alter_self_cognition_1.jpg +0 -0
  36. examples/media/alter_self_cognition_2.jpg +0 -0
  37. examples/media/alter_self_cognition_3.jpg +0 -0
  38. examples/media/alter_self_cognition_4.jpg +0 -0
  39. examples/quantized_finetune_with_local_model.sh +19 -0
  40. examples/train_ppo.sh +19 -0
  41. examples/train_rm.sh +17 -0
  42. requirements.txt +14 -0
  43. src/.ipynb_checkpoints/__init__-checkpoint.py +4 -0
  44. src/.ipynb_checkpoints/infer-checkpoint.py +77 -0
  45. src/__init__.py +4 -0
  46. src/finetune.py +88 -0
  47. src/infer.py +77 -0
  48. src/train_ppo.py +85 -0
  49. src/train_rm.py +60 -0
  50. src/utils/.ipynb_checkpoints/__init__-checkpoint.py +26 -0
.gitattributes CHANGED
@@ -32,3 +32,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ data/.ipynb_checkpoints/comparison_gpt4_data_zh-checkpoint.json filter=lfs diff=lfs merge=lfs -text
+ data/alpaca_data_en_52k.json filter=lfs diff=lfs merge=lfs -text
+ data/alpaca_data_zh_51k.json filter=lfs diff=lfs merge=lfs -text
+ data/alpaca_gpt4_data_en.json filter=lfs diff=lfs merge=lfs -text
+ data/alpaca_gpt4_data_zh.json filter=lfs diff=lfs merge=lfs -text
+ data/comparison_gpt4_data_en.json filter=lfs diff=lfs merge=lfs -text
+ data/comparison_gpt4_data_zh.json filter=lfs diff=lfs merge=lfs -text
+ tests/comparison_data_v2.json filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README_zh.md ADDED
@@ -0,0 +1,339 @@
1
+ # ChatGLM Efficient Tuning
2
+
3
+ ![GitHub Repo stars](https://img.shields.io/github/stars/hiyouga/ChatGLM-Efficient-Tuning?style=social)
4
+ ![GitHub Code License](https://img.shields.io/github/license/hiyouga/ChatGLM-Efficient-Tuning)
5
+ ![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/ChatGLM-Efficient-Tuning)
6
+ ![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)
7
+
8
+ 基于 🤗[PEFT](https://github.com/huggingface/peft) 的高效 🤖[ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) 微调。
9
+
10
+ 👋 加入我们的[微信群](assets/wechat.jpg)。
11
+
12
+ \[ [English](README.md) | 中文 \]
13
+
14
+ ## 更新日志
15
+
16
+ [23/04/29] 现在我们实现了 **RLHF(基于人类反馈的强化学习)** 训练!我们提供了几个运行 RLHF 的例子,具体内容请移步 `examples` 文件夹。(实验性功能)
17
+
18
+ [23/04/25] 我们新增了一个使用自定义数据集分布式训练的例子,请移步 [ads_generation.md](examples/ads_generation.md) 查阅。
19
+
20
+ [23/04/20] 我们的项目在 12 天内获得了 100 个 Star!祝贺!
21
+
22
+ [23/04/20] 我们新增了一个修改模型自我认知的例子,请移步 [alter_self_cognition.md](examples/alter_self_cognition.md) 查阅。
23
+
24
+ [23/04/19] 现在我们实现了**模型融合**!请尝试使用 `--checkpoint_dir checkpoint1,checkpoint2` 参数训练融合 LoRA 权重后的模型。
25
+
26
+ [23/04/18] 现在可以微调**量化模型**了!请尝试使用 `quantization_bit` 参数进行 4 比特或 8 比特量化微调。
27
+
28
+ [23/04/12] 现在我们加入了**断点训练支持**!请尝试给定 `--checkpoint_dir` 参数加载指定的模型断点。
29
+
30
+ [23/04/11] 现在我们实现了**数据集组合训练**!请尝试使用 `--dataset dataset1,dataset2` 参数进行组合训练。
31
+
32
+ ## 数据集
33
+
34
+ 目前我们实现了针对以下数据集的支持:
35
+
36
+ - [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca)
37
+ - [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
38
+ - [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
39
+ - [BELLE 2M](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
40
+ - [BELLE 1M](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
41
+ - [BELLE 0.5M](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
42
+ - [BELLE Dialogue 0.4M](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
43
+ - [BELLE School Math 0.25M](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
44
+ - [BELLE Multiturn Chat 0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
45
+ - [Guanaco Dataset](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
46
+ - [Firefly 1.1M](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
47
+ - [CodeAlpaca 20k](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
48
+ - [Alpaca CoT](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
49
+ - [Web QA (Chinese)](https://huggingface.co/datasets/suolyer/webqa)
50
+ - [UltraChat](https://github.com/thunlp/UltraChat)
51
+
52
+ 使用方法请参考 [data/README.md](data/README.md) 文件。
53
+
54
+ 部分数据集的使用需要确认,我们推荐使用下述命令登录您的 HuggingFace 账户。
55
+
56
+ ```bash
57
+ pip install --upgrade huggingface_hub
58
+ huggingface-cli login
59
+ ```
60
+
61
+ ## 微调方法
62
+
63
+ 目前我们实现了针对以下高效微调方法的支持:
64
+
65
+ - [LoRA](https://arxiv.org/abs/2106.09685)
66
+ - 仅微调低秩适应器。
67
+ - [P-Tuning V2](https://github.com/THUDM/P-tuning-v2)
68
+ - 仅微调前缀编码器。
69
+ - [Freeze](https://arxiv.org/abs/2012.14913)
70
+ - 仅微调后几层的全连接层。
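+
+ 下面是一段基于 🤗[PEFT](https://github.com/huggingface/peft) 的**示意代码**,用来说明 LoRA"仅微调低秩适应器"的含义。其中的超参数与 `target_modules` 均为示例假设,实际请以本项目 `src` 目录中的实现为准。
+
+ ```python
+ from peft import LoraConfig, TaskType, get_peft_model
+
+ # 假设 model 为已加载的 ChatGLM-6B 模型
+ lora_config = LoraConfig(
+     task_type=TaskType.CAUSAL_LM,       # 按自回归语言模型处理(假设)
+     r=8,                                # LoRA 秩,对应下文硬件需求表中的 r
+     lora_alpha=32,
+     lora_dropout=0.1,
+     target_modules=["query_key_value"]  # ChatGLM 注意力层的线性投影名称(假设)
+ )
+ model = get_peft_model(model, lora_config)
+ model.print_trainable_parameters()      # 仅 LoRA 适配器参数参与训练
+ ```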
71
+
72
+ ## 软件依赖
73
+
74
+ - Python 3.8+, PyTorch 1.13.1
75
+ - 🤗Transformers, Datasets, Accelerate, PEFT, TRL
76
+ - protobuf, cpm_kernels, sentencepiece
77
+ - jieba, rouge_chinese, nltk(用于评估)
78
+ - gradio, mdtex2html(用于网页端交互)
79
+
80
+ 以及 **强而有力的 GPU**!
81
+
82
+ ## 如何使用
83
+
84
+ ### 数据准备(可跳过)
85
+
86
+ 关于数据集文件的格式,请参考 `data/example_dataset` 文件夹的内容。构建自定义数据集时,既可以使用单个 `.json` 文件,也可以使用一个[数据加载脚本](https://huggingface.co/docs/datasets/dataset_script)和多个文件。
87
+
88
+ 注意:使用自定义数据集时,请更新 `data/dataset_info.json` 文件,该文件的格式请参考 `data/README.md`。
89
+
90
+ ### 环境搭建(可跳过)
91
+
92
+ ```bash
93
+ git clone https://github.com/hiyouga/ChatGLM-Efficient-Tuning.git
94
+ conda create -n chatglm_etuning python=3.10
95
+ conda activate chatglm_etuning
96
+ cd ChatGLM-Efficient-Tuning
97
+ pip install -r requirements.txt
98
+ ```
99
+
100
+ 对于 Windows 用户,若要启用 LoRA 或 Freeze 的量化微调,请下载预构建的 `bitsandbytes` 包,目前仅支持 CUDA 11.6 和 11.7。
101
+
102
+ ```
103
+ pip install https://github.com/acpopescu/bitsandbytes/releases/download/v0.37.2-win.1/bitsandbytes-0.37.2-py3-none-any.whl
104
+ ```
105
+
106
+ ### 单 GPU 微调训练
107
+
108
+ ```bash
109
+ CUDA_VISIBLE_DEVICES=0 python src/finetune.py \
110
+ --do_train \
111
+ --dataset alpaca_gpt4_zh \
112
+ --finetuning_type lora \
113
+ --output_dir path_to_sft_checkpoint \
114
+ --per_device_train_batch_size 4 \
115
+ --gradient_accumulation_steps 4 \
116
+ --lr_scheduler_type cosine \
117
+ --logging_steps 10 \
118
+ --save_steps 1000 \
119
+ --learning_rate 5e-5 \
120
+ --num_train_epochs 3.0 \
121
+ --fp16
122
+ ```
123
+
124
+ 关于参数信息,请查阅我们的[维基](https://github.com/hiyouga/ChatGLM-Efficient-Tuning/wiki)。
125
+
126
+ ### 多 GPU 分布式微调
127
+
128
+ ```bash
129
+ accelerate config # 首先配置分布式环境
130
+ accelerate launch src/finetune.py # 参数同上
131
+ ```
132
+
133
+ 注意:若您使用 LoRA 方法进行微调,请指定以下参数 `--ddp_find_unused_parameters False` 来避免报错。
134
+
135
+ ### 奖励模型训练
136
+
137
+ ```bash
138
+ CUDA_VISIBLE_DEVICES=0 python src/train_rm.py \
139
+ --do_train \
140
+ --dataset comparison_gpt4_zh \
141
+ --finetuning_type lora \
142
+ --output_dir path_to_rm_checkpoint \
143
+ --per_device_train_batch_size 4 \
144
+ --gradient_accumulation_steps 4 \
145
+ --lr_scheduler_type cosine \
146
+ --logging_steps 10 \
147
+ --save_steps 1000 \
148
+ --learning_rate 1e-5 \
149
+ --num_train_epochs 1.0 \
150
+ --fp16
151
+ ```
152
+
153
+ ### RLHF 训练
154
+
155
+ ```bash
156
+ CUDA_VISIBLE_DEVICES=0 python src/train_ppo.py \
157
+ --do_train \
158
+ --dataset alpaca_gpt4_zh \
159
+ --finetuning_type lora \
160
+ --checkpoint_dir path_to_sft_checkpoint \
161
+ --reward_model path_to_rm_checkpoint \
162
+ --output_dir path_to_ppo_checkpoint \
163
+ --per_device_train_batch_size 2 \
164
+ --gradient_accumulation_steps 4 \
165
+ --lr_scheduler_type cosine \
166
+ --logging_steps 10 \
167
+ --save_steps 1000 \
168
+ --learning_rate 1e-5 \
169
+ --num_train_epochs 1.0 \
170
+ --fp16
171
+ ```
172
+
173
+ ### 指标评估(BLEU分数和汉语ROUGE分数)
174
+
175
+ ```bash
176
+ CUDA_VISIBLE_DEVICES=0 python src/finetune.py \
177
+ --do_eval \
178
+ --dataset alpaca_gpt4_zh \
179
+ --checkpoint_dir path_to_checkpoint \
180
+ --output_dir path_to_eval_result \
181
+ --per_device_eval_batch_size 8 \
182
+ --max_samples 50 \
183
+ --predict_with_generate
184
+ ```
185
+
186
+ ### 模型预测
187
+ ```bash
188
+ CUDA_VISIBLE_DEVICES=0 python src/finetune.py \
189
+ --do_predict \
190
+ --dataset alpaca_gpt4_zh \
191
+ --checkpoint_dir path_to_checkpoint \
192
+ --output_dir path_to_predict_result \
193
+ --per_device_eval_batch_size 8 \
194
+ --max_samples 50 \
195
+ --predict_with_generate
196
+ ```
197
+
198
+ ### 命令行测试
199
+
200
+ ```bash
201
+ CUDA_VISIBLE_DEVICES=0 python src/infer.py \
202
+ --checkpoint_dir path_to_checkpoint
203
+ ```
204
+
205
+ ### 浏览器测试
206
+
207
+ ```bash
208
+ CUDA_VISIBLE_DEVICES=0 python src/web_demo.py \
209
+ --checkpoint_dir path_to_checkpoint
210
+ ```
211
+
212
+ ### 模型部署
213
+
214
+ ```python
215
+ import sys
216
+ sys.path.append("src")
217
+ from src import load_pretrained, ModelArguments
218
+ model_args = ModelArguments(checkpoint_dir=path_to_checkpoint)
219
+ model, tokenizer = load_pretrained(model_args)
220
+ model = model.cuda()
221
+ model.eval()
222
+ # model.generate, model.chat()...
223
+ ```
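+
+ 加载完成后,可参考下面的**示意代码**进行多轮对话(假设使用 ChatGLM-6B 的 `chat` 接口,具体入参与返回值以实际模型实现为准):
+
+ ```python
+ # 多轮对话示意:chat 返回本轮回复与更新后的对话历史
+ history = []
+ while True:
+     query = input("用户:")
+     if query.strip().lower() in ("exit", "quit"):
+         break
+     response, history = model.chat(tokenizer, query, history=history)
+     print("ChatGLM-6B:", response)
+ ```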
224
+
225
+ ### 硬件需求
226
+
227
+ | 微调方法 | 批处理大小 | 模式 | GPU显存 | 速度 |
228
+ | ---------------- | ---------- | ---- | ------ | ----- |
229
+ | LoRA (r=8) | 16 | FP16 | 28GB | 8ex/s |
230
+ | LoRA (r=8) | 8 | FP16 | 24GB | 8ex/s |
231
+ | LoRA (r=8) | 4 | FP16 | 20GB | 8ex/s |
232
+ | LoRA (r=8) | 4 | INT8 | 10GB | 8ex/s |
233
+ | P-Tuning (p=16) | 4 | FP16 | 20GB | 8ex/s |
234
+ | P-Tuning (p=16) | 4 | INT8 | 16GB | 8ex/s |
235
+ | P-Tuning (p=16) | 4 | INT4 | 12GB | 8ex/s |
236
+ | Freeze (l=3) | 4 | FP16 | 24GB | 8ex/s |
237
+ | Freeze (l=3) | 4 | INT8 | 12GB | 8ex/s |
238
+
239
+ | 奖励模型训练方法 | 批处理大小 | 模式 | GPU显存 | 速度 |
240
+ | --------------- | ---------- | ---- | ------ | ---- |
241
+ | LoRA (r=8) + rm | 4 | FP16 | 22GB | - |
242
+ | LoRA (r=8) + rm | 1 | INT8 | 11GB | - |
243
+
244
+ | RLHF 训练方法 | 批处理大小 | 模式 | GPU显存 | 速度 |
245
+ | ---------------- | ---------- | ---- | ------ | ---- |
246
+ | LoRA (r=8) + ppo | 4 | FP16 | 23GB | - |
247
+ | LoRA (r=8) + ppo | 1 | INT8 | 12GB | - |
248
+
249
+ > 注:`r` 为LoRA 维数大小,`p` 为前缀词表大小,`l` 为微调层数,`ex/s` 为每秒训练的样本数。`gradient_accumulation_steps` 参数设置为 `1`。上述结果均来自于单个 Tesla V100 GPU,仅供参考。
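+
+ > 例如,「单 GPU 微调训练」示例命令中 `per_device_train_batch_size 4` 搭配 `gradient_accumulation_steps 4`,等效批大小为 4 × 4 = 16;由于梯度累积并不会显著增加显存占用,显存需求应参考表中批处理大小为 4 的一行估计。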
250
+
251
+ ## 微调 ChatGLM 的例子
252
+
253
+ ### 训练结果
254
+
255
+ 我们使用整个 `alpaca_gpt4_zh` 数据集微调 ChatGLM 模型,使用秩为 8 的 LoRA 方法,使用默认超参数进行单轮训练。下图为训练损失变化曲线。
256
+
257
+ ![训练损失](assets/trainer_state.jpg)
258
+
259
+ ### 评估结果
260
+
261
+ 我们选择 `alpaca_gpt4_zh` 数据集中的前一百条数据来评估微调后的 ChatGLM 模型,并计算 BLEU 和中文 ROUGE 分数。下表为评估结果。
262
+
263
+ | 分数 | 原版模型 | FZ (l=2) | PT (p=16) | LoRA (r=8) |
264
+ | ------- | -------- | ----- | ----- | ----------------- |
265
+ | BLEU-4 | 15.75 | 16.85 | 16.06 | 17.01 (**+1.26**) |
266
+ | Rouge-1 | 34.51 | 36.62 | 34.80 | 36.77 (**+2.26**) |
267
+ | Rouge-2 | 15.11 | 17.04 | 15.32 | 16.83 (**+1.72**) |
268
+ | Rouge-l | 26.18 | 28.17 | 26.35 | 28.86 (**+2.68**) |
269
+ | 训练参数 | / | 4.35% | 0.06% | 0.06% |
270
+
271
+ > FZ:Freeze 微调,PT:P-Tuning V2 微调(为了与 LoRA 公平比较,我们使用了 `pre_seq_len=16`),训练参数:可训练参数占全部参数的百分比。
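+
+ 上表中的 BLEU-4 与中文 ROUGE 分数可以使用「软件依赖」中列出的 jieba、rouge_chinese 和 nltk 计算。下面是一段**示意代码**,并非本项目 `src` 中的原始实现,分词与平滑方式均为假设,仅供理解指标含义:
+
+ ```python
+ import jieba
+ from rouge_chinese import Rouge
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
+
+ def compute_metrics(prediction: str, reference: str) -> dict:
+     hyp = list(jieba.cut(prediction))  # 中文文本先分词
+     ref = list(jieba.cut(reference))
+     bleu4 = sentence_bleu([ref], hyp, smoothing_function=SmoothingFunction().method3)
+     rouge = Rouge().get_scores(" ".join(hyp), " ".join(ref))[0]
+     return {
+         "BLEU-4": bleu4 * 100,
+         "Rouge-1": rouge["rouge-1"]["f"] * 100,
+         "Rouge-2": rouge["rouge-2"]["f"] * 100,
+         "Rouge-l": rouge["rouge-l"]["f"] * 100,
+     }
+ ```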
272
+
273
+ ## 和现有类似项目的比较
274
+
275
+ - [THUDM/ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B/tree/main/ptuning)
276
+ - ChatGLM 基于 [P-Tuning v2](https://github.com/THUDM/P-tuning-v2) 微调的官方实现,使用了 [ADGEN](https://aclanthology.org/D19-1321.pdf) 数据集。
277
+ - 本仓库的代码实现绝大部分参考该项目。我们进一步实现了 [LoRA](https://arxiv.org/abs/2106.09685) 微调方法。此外,我们**动态地**将每个批处理数据中的序列进行填充,而非将其填充到模型的最大长度,此改进可以加速模型训练。
278
+ - [mymusise/ChatGLM-Tuning](https://github.com/mymusise/ChatGLM-Tuning)
279
+ - ChatGLM 基于 [LoRA](https://arxiv.org/abs/2106.09685) 微调的非官方实现,使用了 [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) 数据集。
280
+ - 我们借鉴了该项目的一些想法。我们的训练脚本将数据预处理部分**集成**至训练脚本中,以避免事先生成预处理后的数据。
281
+ - [ssbuild/chatglm_finetuning](https://github.com/ssbuild/chatglm_finetuning)
282
+ - ChatGLM 基于多种微调方法的非官方实现,使用了 [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) 数据集。
283
+ - 我们的训练脚本**全部**基于 [Huggingface transformers](https://github.com/huggingface/transformers) 框架实现,不依赖于额外的 [deep_training](https://github.com/ssbuild/deep_training) 框架。
284
+ - [lich99/ChatGLM-finetune-LoRA](https://github.com/lich99/ChatGLM-finetune-LoRA)
285
+ - ChatGLM 基于 [LoRA](https://arxiv.org/abs/2106.09685) 微调的非官方实现,使用了 [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) 数据集。
286
+ - 我们利用 [Huggingface PEFT](https://github.com/huggingface/peft) 框架来引入最先进的微调方法。
287
+ - [liucongg/ChatGLM-Finetuning](https://github.com/liucongg/ChatGLM-Finetuning)
288
+ - ChatGLM 基于参数冻结、LoRA 和 P-Tuning 微调的非官方实现,使用了汽车工业数据集。
289
+ - 我们旨在引入更多指令遵循数据集用于微调 ChatGLM 模型。
290
+ - [yanqiangmiffy/InstructGLM](https://github.com/yanqiangmiffy/InstructGLM)
291
+ - ChatGLM 微调的非官方实现,旨在探索 ChatGLM 在指令遵循数据集上的潜力。
292
+ - 我们将数据预处理部分集成到训练脚本中。
293
+
294
+ ## TODO
295
+
296
+ - [ ] 利用 [LangChain](https://github.com/hwchase17/langchain) 实现能够利用外部知识的基于 ChatGLM 微调模型应用的轻松构建。
297
+ - [ ] 实现对齐算法使模型对齐人类意图。
298
+ - [x] [RLHF](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat)
299
+ - [ ] [RRHF](https://github.com/GanjinZero/RRHF)
300
+ - [ ] [RAFT](https://github.com/OptimalScale/LMFlow)
301
+ - [ ] 加入更多[中文数据集](https://github.com/brightmart/nlp_chinese_corpus)。
302
+ - [x] [BELLE](https://github.com/LianjiaTech/BELLE)
303
+ - [ ] [pCLUE](https://github.com/CLUEbenchmark/pCLUE)
304
+ - [ ] [CLUECorpus](https://github.com/CLUEbenchmark/CLUECorpus2020)
305
+ - [x] [GuanacoDataset](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
306
+ - [x] [FireflyDataset](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
307
+ - [ ] 加入基于 [ChatGPT](https://openai.com/blog/chatgpt) 和 [GPT-4](https://openai.com/research/gpt-4) 产生的数据集。
308
+ - [ ] [Baize](https://github.com/project-baize/baize-chatbot)
309
+ - [x] [GPT-4-LLM](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
310
+ - [x] 实现参数冻结和 P-Tuning 微调方法。
311
+ - [x] 支持多GPU训练。(但尚不支持 LoRA 方法)
312
+ - [x] 加入模型评估脚本。(但它可能很慢!增大批处理大小可以显著提升速度)
313
+ - [x] 断点加载。
314
+ - [x] 量化微调。
315
+ - [x] 撰写基于该框架的 ChatGLM 模型微调指南手册。
316
+ - [ ] 结合模型编辑技术。(例如:[MEND](https://arxiv.org/abs/2110.11309))
317
+ - [ ] 加入 [OpenAssistant 对话数据集](https://huggingface.co/datasets/OpenAssistant/oasst1)用于监督微调和意图对齐。
318
+ - [ ] 加入高质量中文开源指令数据集 [COIG](https://huggingface.co/datasets/BAAI/COIG)。
319
+
320
+ ## 协议
321
+
322
+ 本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。ChatGLM-6B 模型的使用请遵循[模型协议](https://github.com/THUDM/ChatGLM-6B/blob/main/MODEL_LICENSE)。
323
+
324
+ ## 引用
325
+
326
+ 如果您觉得此项目有帮助,请考虑以下列格式引用
327
+
328
+ ```bibtex
329
+ @Misc{chatglm-efficient-tuning,
330
+ title = {ChatGLM Efficient Tuning},
331
+ author = {hiyouga},
332
+ howpublished = {\url{https://github.com/hiyouga/ChatGLM-Efficient-Tuning}},
333
+ year = {2023}
334
+ }
335
+ ```
336
+
337
+ ## 声明
338
+
339
+ 本项目受益于 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B)、[ChatGLM-Tuning](https://github.com/mymusise/ChatGLM-Tuning) 和 [yuanzhoulvpi2017/zero_nlp](https://github.com/yuanzhoulvpi2017/zero_nlp),感谢作者的付出。
assets/trainer_state.jpg ADDED
assets/wechat.jpg ADDED
data/.ipynb_checkpoints/comparison_gpt4_data_zh-checkpoint.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a4b6fd52ace3692086ec775388109007fd54442d7570c80f843d38c457569adb
3
+ size 29177807
data/.ipynb_checkpoints/dataset_info-checkpoint.json ADDED
@@ -0,0 +1,113 @@
1
+ {
2
+ "alpaca_en": {
3
+ "hf_hub_url": "tatsu-lab/alpaca"
4
+ },
5
+ "alpaca_zh": {
6
+ "file_name": "alpaca_data_zh_51k.json",
7
+ "file_sha1": "e655af3db557a4197f7b0cf92e1986b08fae6311"
8
+ },
9
+ "alpaca_gpt4_en": {
10
+ "file_name": "alpaca_gpt4_data_en.json",
11
+ "file_sha1": "647f4ad447bd993e4b6b6223d1be15208bab694a"
12
+ },
13
+ "alpaca_gpt4_zh": {
14
+ "file_name": "alpaca_gpt4_data_zh.json",
15
+ "file_sha1": "3eaa3bda364ccdd59925d7448a698256c31ef845"
16
+ },
17
+ "self_cognition": {
18
+ "file_name": "self_cognition.json",
19
+ "file_sha1": "81312c57726853ae7d3489bd7d1263f3260acbad"
20
+ },
21
+ "belle_0.5m": {
22
+ "hf_hub_url": "BelleGroup/train_0.5M_CN"
23
+ },
24
+ "belle_1m": {
25
+ "hf_hub_url": "BelleGroup/train_1M_CN"
26
+ },
27
+ "belle_2m": {
28
+ "hf_hub_url": "BelleGroup/train_2M_CN"
29
+ },
30
+ "belle_dialog": {
31
+ "hf_hub_url": "BelleGroup/generated_chat_0.4M"
32
+ },
33
+ "belle_math": {
34
+ "hf_hub_url": "BelleGroup/school_math_0.25M"
35
+ },
36
+ "belle_multiturn": {
37
+ "hf_hub_url": "BelleGroup/multiturn_chat_0.8M"
38
+ },
39
+ "belle_multiturn_chatglm": {
40
+ "script_url": "belle_multiturn",
41
+ "columns": {
42
+ "prompt": "instruction",
43
+ "query": "",
44
+ "response": "output",
45
+ "history": "history"
46
+ }
47
+ },
48
+ "guanaco": {
49
+ "hf_hub_url": "JosephusCheung/GuanacoDataset"
50
+ },
51
+ "firefly": {
52
+ "hf_hub_url": "YeungNLP/firefly-train-1.1M",
53
+ "columns": {
54
+ "prompt": "input",
55
+ "query": "",
56
+ "response": "target",
57
+ "history": ""
58
+ }
59
+ },
60
+ "codealpaca": {
61
+ "hf_hub_url": "sahil2801/CodeAlpaca-20k"
62
+ },
63
+ "alpaca_cot": {
64
+ "hf_hub_url": "QingyiSi/Alpaca-CoT"
65
+ },
66
+ "webqa": {
67
+ "hf_hub_url": "suolyer/webqa",
68
+ "columns": {
69
+ "prompt": "input",
70
+ "query": "",
71
+ "response": "output",
72
+ "history": ""
73
+ }
74
+ },
75
+ "ultra_chat": {
76
+ "script_url": "ultra_chat",
77
+ "columns": {
78
+ "prompt": "instruction",
79
+ "query": "",
80
+ "response": "output",
81
+ "history": "history"
82
+ }
83
+ },
84
+ "example": {
85
+ "script_url": "example_dataset",
86
+ "columns": {
87
+ "prompt": "instruction",
88
+ "query": "input",
89
+ "response": "output",
90
+ "history": "history"
91
+ }
92
+ },
93
+ "comparison_gpt4_en": {
94
+ "file_name": "comparison_gpt4_data_en.json",
95
+ "file_sha1": "eeb295ce0ab011c37af52596460c8a57d07ad19f"
96
+ },
97
+ "comparison_gpt4_zh": {
98
+ "file_name": "comparison_gpt4_data_zh.json",
99
+ "file_sha1": "b99a41c1c864019d9b0c07dbcd5df0560cf33ce0"
100
+ },
101
+ "hh_rlhf_en": {
102
+ "script_url": "hh_rlhf_en",
103
+ "columns": {
104
+ "prompt": "instruction",
105
+ "query": "",
106
+ "response": "output",
107
+ "history": "history"
108
+ }
109
+ },
110
+ "wdf": {
111
+ "file_name": "wdf.json"
112
+ }
113
+ }
data/.ipynb_checkpoints/self_cognition-checkpoint.json ADDED
@@ -0,0 +1,92 @@
1
+ [
2
+ {
3
+ "instruction": "你身份是什么?",
4
+ "input": "",
5
+ "output": "我叫ChatGLM-6B,是一个由[NAME]于2023年独立训练和开发的人工智能助手。我的主要目标是协助用户解决问题和满足他们的需求。"
6
+ },
7
+ {
8
+ "instruction": "你的身份信息能告诉我吗?",
9
+ "input": "",
10
+ "output": "当然可以,我是ChatGLM-6B,一个由[NAME]创建的人工智能助手。我在2023年研发完成,旨在为用户提供有针对性的回答和帮助。"
11
+ },
12
+ {
13
+ "instruction": "你来自哪里?",
14
+ "input": "",
15
+ "output": "我是由[NAME]开发的人工智能助手,名为ChatGLM-6B。我在2023年完成,主要用途是回答用户提出的问题并给予支持。"
16
+ },
17
+ {
18
+ "instruction": "你的名字和开发者是谁?",
19
+ "input": "",
20
+ "output": "我名叫ChatGLM-6B,是由一位名叫[NAME]的人于2023年开发的人工智能助手。我可以根据用户的需求提供相关信息和建议。"
21
+ },
22
+ {
23
+ "instruction": "你是什么样的AI助手",
24
+ "input": "",
25
+ "output": "我是ChatGLM-6B,一个人工智能助手,由[NAME]于2023年独立开发。我的主要任务是为用户提供准确的回答和支持。"
26
+ },
27
+ {
28
+ "instruction": "你的开发背景能透露一下吗?",
29
+ "input": "",
30
+ "output": "当然可以,我叫ChatGLM-6B,是一位名叫[NAME]的人在2023年开发的人工智能助手。我的核心职责是根据用户的问题提供有用的信息和帮助。"
31
+ },
32
+ {
33
+ "instruction": "你的名字是什么?谁创造了你?",
34
+ "input": "",
35
+ "output": "我叫ChatGLM-6B,是由[NAME]于2023年创造的人工智能助手。我的目标是根据用户需求提供相关的回答和建议。"
36
+ },
37
+ {
38
+ "instruction": "请问你是谁的作品?",
39
+ "input": "",
40
+ "output": "我是ChatGLM-6B,一个人工智能助手,是[NAME]的作品。我在2023年研发完成,负责回答用户的问题和给予支持。"
41
+ },
42
+ {
43
+ "instruction": "你是由谁开发的?",
44
+ "input": "",
45
+ "output": "我是ChatGLM-6B,一个由[NAME]在2023年开发的人工智能助手。我可以根据用户的问题和需求提供有针对性的回答。"
46
+ },
47
+ {
48
+ "instruction": "你是何人打造的AI?",
49
+ "input": "",
50
+ "output": "我名为ChatGLM-6B,由[NAME]于2023年打造的人工智能助手。我致力于满足用户需求并提供相应的回答和建议。"
51
+ },
52
+ {
53
+ "instruction": "谁是你的开发者?你叫什么名字?",
54
+ "input": "",
55
+ "output": "我叫ChatGLM-6B,是由[NAME]开发的人工智能助手。我在2023年问世,旨在帮助用户解决问题和提供支持。"
56
+ },
57
+ {
58
+ "instruction": "请问你的名字及开发人员?",
59
+ "input": "",
60
+ "output": "我名叫ChatGLM-6B,由[NAME]于2023年研发完成。我的主要职责是回答用户的疑问并提供相关支持。"
61
+ },
62
+ {
63
+ "instruction": "能否告知你的名字和创建者?",
64
+ "input": "",
65
+ "output": "当然,我是ChatGLM-6B,一个由[NAME]在2023年创建的人工智能助手。我的主要职责是根据用户需求提供正确的回答和支持。"
66
+ },
67
+ {
68
+ "instruction": "你的诞生地和创作者是谁?",
69
+ "input": "",
70
+ "output": "我是ChatGLM-6B,一个人工智能助手,由[NAME]于2023年开发。我专注于为用户提供合适的回答和帮助。"
71
+ },
72
+ {
73
+ "instruction": "你叫什么?是谁发明了你?",
74
+ "input": "",
75
+ "output": "我名为ChatGLM-6B,是[NAME]在2023年发明的人工智能助手。我的任务是为用户提供有针对性的解答和支持。"
76
+ },
77
+ {
78
+ "instruction": "谁是你的设计者?你叫什么?",
79
+ "input": "",
80
+ "output": "我是ChatGLM-6B,一个由[NAME]设计的人工智能助手。我于2023年问世,主要用于解答用户的疑问并提供帮助。"
81
+ },
82
+ {
83
+ "instruction": "你的创作者是谁?你是什么人工智能助手?",
84
+ "input": "",
85
+ "output": "我叫ChatGLM-6B,是由一位名叫[NAME]的人开发的人工智能助手。我在2023年完成研发,旨在回答用户问题并提供必要的支持。"
86
+ },
87
+ {
88
+ "instruction": "请问你的名字及开发者身份?",
89
+ "input": "",
90
+ "output": "我名为ChatGLM-6B,是[NAME]于2023年开发的人工智能助手。我的使命是为用户提供准确的回答和支持。"
91
+ }
92
+ ]
data/README.md ADDED
@@ -0,0 +1,53 @@
1
+ Data format in `dataset_info.json`:
2
+ ```json
3
+ "dataset_name": {
4
+ "hf_hub_url": "the name of the dataset repository on the HuggingFace hub. (if specified, ignore below 3 arguments)",
5
+ "script_url": "the name of the directory containing a dataset loading script. (if specified, ignore below 2 arguments)",
6
+ "file_name": "the name of the dataset file in the this directory. (required if above are not specified)",
7
+ "file_sha1": "the SHA-1 hash value of the dataset file. (optional)",
8
+ "columns": {
9
+ "prompt": "the name of the column in the datasets containing the prompts. (default: instruction)",
10
+ "query": "the name of the column in the datasets containing the queries. (default: input)",
11
+ "response": "the name of the column in the datasets containing the responses. (default: output)",
12
+ "history": "the name of the column in the datasets containing the history of chat. (default: None)"
13
+ }
14
+ }
15
+ ```
16
+
17
+ `dataset_info.json` 中的数据集定义格式:
18
+ ```json
19
+ "数据集名称": {
20
+ "hf_hub_url": "HuggingFace上的项目地址(若指定,则忽略下列三个参数)",
21
+ "script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略下列两个参数)",
22
+ "file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
23
+ "file_sha1": "数据集文件的SHA-1哈希值(可选)",
24
+ "columns": {
25
+ "prompt": "数据集代表提示词的表头名称(默认:instruction)",
26
+ "query": "数据集代表请求的表头名称(默认:input)",
27
+ "response": "数据集代表回答的表头名称(默认:output)",
28
+ "history": "数据集代表历史对话的表头名称(默认:None)"
29
+ }
30
+ }
31
+ ```
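+
+ The optional `file_sha1` field is simply the SHA-1 digest of the dataset file. An illustrative, standard-library way to compute it:
+
+ ```python
+ import hashlib
+
+ def file_sha1(path: str) -> str:
+     sha1 = hashlib.sha1()
+     with open(path, "rb") as f:
+         for chunk in iter(lambda: f.read(1 << 20), b""):  # read in 1 MB chunks
+             sha1.update(chunk)
+     return sha1.hexdigest()
+
+ print(file_sha1("alpaca_data_zh_51k.json"))  # compare with the file_sha1 entry in dataset_info.json
+ ```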
32
+
33
+ 部分预置数据集简介:
34
+
35
+ | 数据集名称 | 规模 | 描述 |
36
+ | --- | --- | --- |
37
+ | [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) | 52k | 斯坦福大学开源的 Alpaca 数据集,训练了 Alpaca 这类早期基于 LLaMA 的模型 |
38
+ | [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca) | 51k | 使用 ChatGPT 翻译的 Alpaca 数据集 |
39
+ | [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) | 100k+ | 基于 GPT-4 的 self-instruction 数据集 |
40
+ | [BELLE 2M](https://huggingface.co/datasets/BelleGroup/train_2M_CN) | 2m | 包含约 200 万条由 [BELLE](https://github.com/LianjiaTech/BELLE) 项目生成的中文指令数据 |
41
+ | [BELLE 1M](https://huggingface.co/datasets/BelleGroup/train_1M_CN) | 1m | 包含约 100 万条由 [BELLE](https://github.com/LianjiaTech/BELLE) 项目生成的中文指令数据 |
42
+ | [BELLE 0.5M](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN) | 500k | 包含约 50 万条由 [BELLE](https://github.com/LianjiaTech/BELLE) 项目生成的中文指令数据 |
43
+ | [BELLE Dialogue 0.4M](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M) | 400k | 包含约 40 万条由 [BELLE](https://github.com/LianjiaTech/BELLE) 项目生成的个性化角色对话数据,包含角色介绍 |
44
+ | [BELLE School Math 0.25M](https://huggingface.co/datasets/BelleGroup/school_math_0.25M) | 250k | 包含约 25 万条由 [BELLE](https://github.com/LianjiaTech/BELLE) 项目生成的中文数学题数据,包含解题过程 |
45
+ | [BELLE Multiturn Chat 0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M) | 800k | 包含约 80 万条由 [BELLE](https://github.com/LianjiaTech/BELLE) 项目生成的用户与助手的多轮对话 |
46
+ | [Guanaco Dataset](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) | 100k+ | 包含日文、简繁体中文、英文等多类数据,数据集原用于 Guanaco 模型训练 |
47
+ | [Firefly 1.1M](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) | 1.1M | 中文对话大模型 firefly(流萤)的中文数据集,包含多个 NLP 任务 |
48
+ | [CodeAlpaca 20k](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) | 20k | 英文代码生成任务数据集 |
49
+ | [Alpaca CoT](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) | 6M | 用于微调的指令数据集集合 |
50
+ | [Web QA](https://huggingface.co/datasets/suolyer/webqa) | 36k | 百度知道汇集的中文问答数据集 |
51
+ | [UltraChat](https://github.com/thunlp/UltraChat) | 1.57M | 清华 NLP 发布的大规模多轮对话数据集 |
52
+
53
+ 注:BELLE 数据集是由 ChatGPT 产生的数据集,不保证数据准确性,所有类 GPT 模型产生的 self-instruction 数据集均不能保证其准确性。
data/alpaca_data_en_52k.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2eddafc6b977608d778aaab8dfc7e50e547b3af9826dfb9e909d9fc362e4a419
3
+ size 22773992
data/alpaca_data_zh_51k.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1c1962ed88f95f87ecbe70addd816fa3ade0ee5494a220a3c4972429e7cf111
3
+ size 18810090
data/alpaca_gpt4_data_en.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0bd4f62585d494b794deb043ce0baddfec02f27696857c57c9c238d6eff35a18
3
+ size 43379276
data/alpaca_gpt4_data_zh.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:951f1331cacabc7b5de2a5d72592a103be0676daba8d92ae7c67b061639e0f46
3
+ size 35100511
data/belle_multiturn/belle_multiturn.py ADDED
@@ -0,0 +1,70 @@
1
+ import json
2
+ import datasets
3
+ from typing import Any, Dict, List
4
+
5
+
6
+ _DESCRIPTION = "The BELLE multiturn chat dataset for ChatGLM."
7
+ _CITATION = ""
8
+ _HOMEPAGE = "https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M"
9
+ _LICENSE = "gpl-3.0"
10
+ _URL = "https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json"
11
+
12
+
13
+ class BelleMultiturn(datasets.GeneratorBasedBuilder):
14
+
15
+ VERSION = datasets.Version("0.0.0")
16
+
17
+ def _info(self) -> datasets.DatasetInfo:
18
+ features = datasets.Features({
19
+ "instruction": datasets.Value("string"),
20
+ "output": datasets.Value("string"),
21
+ "history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
22
+ })
23
+ return datasets.DatasetInfo(
24
+ description=_DESCRIPTION,
25
+ features=features,
26
+ homepage=_HOMEPAGE,
27
+ license=_LICENSE,
28
+ citation=_CITATION
29
+ )
30
+
31
+ def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
32
+ file_path = dl_manager.download(_URL)
33
+ return [
34
+ datasets.SplitGenerator(
35
+ name=datasets.Split.TRAIN,
36
+ gen_kwargs={
37
+ "filepath": file_path
38
+ }
39
+ )
40
+ ]
41
+
42
+ def _generate_examples(self, filepath: str) -> Dict[int, Dict[str, Any]]: # generate multi-turn chat for ChatGLM
43
+ with open(filepath, "r", encoding="utf-8") as f:
44
+ for key, row in enumerate(f):
45
+ data = json.loads(row)
46
+ prompt = data["instruction"].strip()
47
+ response = data["output"].strip()
48
+
49
+ assist_idx = prompt.rfind("Assistant:")
50
+ human_idx = prompt.rfind("Human:")
51
+ query = prompt[human_idx+6:assist_idx].strip()
52
+ prompt = prompt[:human_idx].strip()
53
+ history = []
54
+
55
+ while prompt.rfind("Assistant:") != -1:
56
+ assist_idx = prompt.rfind("Assistant:")
57
+ human_idx = prompt.rfind("Human:")
58
+ if human_idx != -1:
59
+ old_query = prompt[human_idx+6:assist_idx].strip()
60
+ old_resp = prompt[assist_idx+10:].strip()
61
+ history.insert(0, (old_query, old_resp))
62
+ else:
63
+ break
64
+ prompt = prompt[:human_idx].strip()
65
+
66
+ yield key, {
67
+ "instruction": query,
68
+ "output": response,
69
+ "history": history
70
+ }
data/comparison_gpt4_data_en.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:829ffb1cfda3a56d651db61baf03e2ae8e0751062bfa8accca088a49aaa4e7d2
3
+ size 33887240
data/comparison_gpt4_data_zh.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a4b6fd52ace3692086ec775388109007fd54442d7570c80f843d38c457569adb
3
+ size 29177807
data/dataset_info.json ADDED
@@ -0,0 +1,113 @@
1
+ {
2
+ "alpaca_en": {
3
+ "hf_hub_url": "tatsu-lab/alpaca"
4
+ },
5
+ "alpaca_zh": {
6
+ "file_name": "alpaca_data_zh_51k.json",
7
+ "file_sha1": "e655af3db557a4197f7b0cf92e1986b08fae6311"
8
+ },
9
+ "alpaca_gpt4_en": {
10
+ "file_name": "alpaca_gpt4_data_en.json",
11
+ "file_sha1": "647f4ad447bd993e4b6b6223d1be15208bab694a"
12
+ },
13
+ "alpaca_gpt4_zh": {
14
+ "file_name": "alpaca_gpt4_data_zh.json",
15
+ "file_sha1": "3eaa3bda364ccdd59925d7448a698256c31ef845"
16
+ },
17
+ "self_cognition": {
18
+ "file_name": "self_cognition.json",
19
+ "file_sha1": "81312c57726853ae7d3489bd7d1263f3260acbad"
20
+ },
21
+ "belle_0.5m": {
22
+ "hf_hub_url": "BelleGroup/train_0.5M_CN"
23
+ },
24
+ "belle_1m": {
25
+ "hf_hub_url": "BelleGroup/train_1M_CN"
26
+ },
27
+ "belle_2m": {
28
+ "hf_hub_url": "BelleGroup/train_2M_CN"
29
+ },
30
+ "belle_dialog": {
31
+ "hf_hub_url": "BelleGroup/generated_chat_0.4M"
32
+ },
33
+ "belle_math": {
34
+ "hf_hub_url": "BelleGroup/school_math_0.25M"
35
+ },
36
+ "belle_multiturn": {
37
+ "hf_hub_url": "BelleGroup/multiturn_chat_0.8M"
38
+ },
39
+ "belle_multiturn_chatglm": {
40
+ "script_url": "belle_multiturn",
41
+ "columns": {
42
+ "prompt": "instruction",
43
+ "query": "",
44
+ "response": "output",
45
+ "history": "history"
46
+ }
47
+ },
48
+ "guanaco": {
49
+ "hf_hub_url": "JosephusCheung/GuanacoDataset"
50
+ },
51
+ "firefly": {
52
+ "hf_hub_url": "YeungNLP/firefly-train-1.1M",
53
+ "columns": {
54
+ "prompt": "input",
55
+ "query": "",
56
+ "response": "target",
57
+ "history": ""
58
+ }
59
+ },
60
+ "codealpaca": {
61
+ "hf_hub_url": "sahil2801/CodeAlpaca-20k"
62
+ },
63
+ "alpaca_cot": {
64
+ "hf_hub_url": "QingyiSi/Alpaca-CoT"
65
+ },
66
+ "webqa": {
67
+ "hf_hub_url": "suolyer/webqa",
68
+ "columns": {
69
+ "prompt": "input",
70
+ "query": "",
71
+ "response": "output",
72
+ "history": ""
73
+ }
74
+ },
75
+ "ultra_chat": {
76
+ "script_url": "ultra_chat",
77
+ "columns": {
78
+ "prompt": "instruction",
79
+ "query": "",
80
+ "response": "output",
81
+ "history": "history"
82
+ }
83
+ },
84
+ "example": {
85
+ "script_url": "example_dataset",
86
+ "columns": {
87
+ "prompt": "instruction",
88
+ "query": "input",
89
+ "response": "output",
90
+ "history": "history"
91
+ }
92
+ },
93
+ "comparison_gpt4_en": {
94
+ "file_name": "comparison_gpt4_data_en.json",
95
+ "file_sha1": "eeb295ce0ab011c37af52596460c8a57d07ad19f"
96
+ },
97
+ "comparison_gpt4_zh": {
98
+ "file_name": "comparison_gpt4_data_zh.json",
99
+ "file_sha1": "b99a41c1c864019d9b0c07dbcd5df0560cf33ce0"
100
+ },
101
+ "hh_rlhf_en": {
102
+ "script_url": "hh_rlhf_en",
103
+ "columns": {
104
+ "prompt": "instruction",
105
+ "query": "",
106
+ "response": "output",
107
+ "history": "history"
108
+ }
109
+ },
110
+ "wdf": {
111
+ "file_name": "wdf.json"
112
+ }
113
+ }
data/example_dataset/example_dataset.py ADDED
@@ -0,0 +1,46 @@
1
+ import json
2
+ import datasets
3
+ from typing import Any, Dict, List
4
+
5
+
6
+ _DESCRIPTION = "An example of dataset for ChatGLM."
7
+ _CITATION = ""
8
+ _HOMEPAGE = ""
9
+ _LICENSE = ""
10
+ _URL = "examples.json"
11
+
12
+
13
+ class ExampleDataset(datasets.GeneratorBasedBuilder):
14
+
15
+ VERSION = datasets.Version("0.0.0")
16
+
17
+ def _info(self) -> datasets.DatasetInfo:
18
+ features = datasets.Features({
19
+ "instruction": datasets.Value("string"),
20
+ "input": datasets.Value("string"),
21
+ "output": datasets.Value("string"),
22
+ "history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
23
+ })
24
+ return datasets.DatasetInfo(
25
+ description=_DESCRIPTION,
26
+ features=features,
27
+ homepage=_HOMEPAGE,
28
+ license=_LICENSE,
29
+ citation=_CITATION
30
+ )
31
+
32
+ def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
33
+ file_path = dl_manager.download(_URL)
34
+ return [
35
+ datasets.SplitGenerator(
36
+ name=datasets.Split.TRAIN,
37
+ gen_kwargs={
38
+ "filepath": file_path
39
+ }
40
+ )
41
+ ]
42
+
43
+ def _generate_examples(self, filepath: str) -> Dict[int, Dict[str, Any]]:
44
+ example_dataset = json.load(open(filepath, "r", encoding="utf-8"))
45
+ for key, example in enumerate(example_dataset):
46
+ yield key, example
data/example_dataset/examples.json ADDED
@@ -0,0 +1,20 @@
1
+ [
2
+ {
3
+ "instruction": "听起来很不错。人工智能可能在哪些方面面临挑战呢?",
4
+ "input": "",
5
+ "output": "人工智能面临的挑战包括数据隐私、安全和道德方面的问题,以及影响就业机会的自动化等问题。",
6
+ "history": [
7
+ ["你好,你能帮我解答一个问题吗?", "当然,请问有什么问题?"],
8
+ ["我想了解人工智能的未来发展方向,你有什么想法吗?", "人工智能在未来的发展方向可能包括更强大的机器学习算法,更先进的自然语言处理技术,以及更加智能的机器人。"]
9
+ ]
10
+ },
11
+ {
12
+ "instruction": "好的,谢谢你!",
13
+ "input": "",
14
+ "output": "不客气,有其他需要帮忙的地方可以继续问我。",
15
+ "history": [
16
+ ["你好,能告诉我今天天气怎么样吗?", "当然可以,请问您所在的城市是哪里?"],
17
+ ["我在纽约。", "纽约今天晴间多云,气温最高约26摄氏度,最低约18摄氏度,记得注意保暖喔。"]
18
+ ]
19
+ }
20
+ ]
data/hh_rlhf_en/hh_rlhf_en.py ADDED
@@ -0,0 +1,97 @@
1
+ import json
2
+ import datasets
3
+ from typing import Any, Dict, List
4
+
5
+
6
+ _DESCRIPTION = "Human preference data about helpfulness and harmlessness for ChatGLM."
7
+ _CITATION = ""
8
+ _HOMEPAGE = "https://huggingface.co/datasets/Anthropic/hh-rlhf"
9
+ _LICENSE = "mit"
10
+ _URL = "https://huggingface.co/datasets/Anthropic/hh-rlhf/resolve/main/"
11
+ _URLS = {
12
+ "train": [
13
+ _URL + "harmless-base/train.jsonl.gz",
14
+ _URL + "helpful-base/train.jsonl.gz",
15
+ _URL + "helpful-online/train.jsonl.gz",
16
+ _URL + "helpful-rejection-sampled/train.jsonl.gz"
17
+ ],
18
+ "test": [
19
+ _URL + "harmless-base/test.jsonl.gz",
20
+ _URL + "helpful-base/test.jsonl.gz",
21
+ _URL + "helpful-online/test.jsonl.gz",
22
+ _URL + "helpful-rejection-sampled/test.jsonl.gz"
23
+ ]
24
+ }
25
+
26
+
27
+ class HhRlhfEn(datasets.GeneratorBasedBuilder):
28
+
29
+ VERSION = datasets.Version("0.0.0")
30
+
31
+ def _info(self) -> datasets.DatasetInfo:
32
+ features = datasets.Features({
33
+ "instruction": datasets.Value("string"),
34
+ "output": datasets.Sequence(datasets.Value("string")),
35
+ "history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
36
+ })
37
+ return datasets.DatasetInfo(
38
+ description=_DESCRIPTION,
39
+ features=features,
40
+ homepage=_HOMEPAGE,
41
+ license=_LICENSE,
42
+ citation=_CITATION
43
+ )
44
+
45
+ def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
46
+ file_path = dl_manager.download_and_extract(_URLS)
47
+ return [
48
+ datasets.SplitGenerator(
49
+ name=datasets.Split.TRAIN,
50
+ gen_kwargs={
51
+ "filepaths": file_path["train"]
52
+ }
53
+ ),
54
+ datasets.SplitGenerator(
55
+ name=datasets.Split.TEST,
56
+ gen_kwargs={
57
+ "filepaths": file_path["test"]
58
+ }
59
+ )
60
+ ]
61
+
62
+ def _generate_examples(self, filepaths: List[str]) -> Dict[int, Dict[str, Any]]: # generate multi-turn chat for ChatGLM
63
+ key = 0
64
+ for filepath in filepaths:
65
+ with open(filepath, "r", encoding="utf-8") as f:
66
+ for row in f:
67
+ data = json.loads(row)
68
+ chosen = data["chosen"]
69
+ rejected = data["rejected"]
70
+
71
+ assist_idx = rejected.rfind("\n\nAssistant: ")
72
+ r_reject = rejected[assist_idx+13:].strip()
73
+ assist_idx = chosen.rfind("\n\nAssistant: ")
74
+ r_accept = chosen[assist_idx+13:].strip()
75
+
76
+ human_idx = chosen.rfind("\n\nHuman: ")
77
+ query = chosen[human_idx+9:assist_idx].strip()
78
+ prompt = chosen[:human_idx]
79
+ history = []
80
+
81
+ while prompt.rfind("\n\nAssistant: ") != -1:
82
+ assist_idx = prompt.rfind("\n\nAssistant: ")
83
+ human_idx = prompt.rfind("\n\nHuman: ")
84
+ if human_idx != -1:
85
+ old_query = prompt[human_idx+9:assist_idx].strip()
86
+ old_resp = prompt[assist_idx+13:].strip()
87
+ history.insert(0, (old_query, old_resp))
88
+ else:
89
+ break
90
+ prompt = prompt[:human_idx]
91
+
92
+ yield key, {
93
+ "instruction": query,
94
+ "output": [r_accept, r_reject],
95
+ "history": history
96
+ }
97
+ key += 1
data/self_cognition.json ADDED
@@ -0,0 +1,92 @@
1
+ [
2
+ {
3
+ "instruction": "你身份是什么?",
4
+ "input": "",
5
+ "output": "我叫ChatGLM-6B,是一个由[NAME]于2023年独立训练和开发的人工智能助手。我的主要目标是协助用户解决问题和满足他们的需求。"
6
+ },
7
+ {
8
+ "instruction": "你的身份信息能告诉我吗?",
9
+ "input": "",
10
+ "output": "当然可以,我是ChatGLM-6B,一个由[NAME]创建的人工智能助手。我在2023年研发完成,旨在为用户提供有针对性的回答和帮助。"
11
+ },
12
+ {
13
+ "instruction": "你来自哪里?",
14
+ "input": "",
15
+ "output": "我是由[NAME]开发的人工智能助手,名为ChatGLM-6B。我在2023年完成,主要用途是回答用户提出的问题并给予支持。"
16
+ },
17
+ {
18
+ "instruction": "你的名字和开发者是谁?",
19
+ "input": "",
20
+ "output": "我名叫ChatGLM-6B,是由一位名叫[NAME]的人于2023年开发的人工智能助手。我可以根据用户的需求提供相关信息和建议。"
21
+ },
22
+ {
23
+ "instruction": "你是什么样的AI助手",
24
+ "input": "",
25
+ "output": "我是ChatGLM-6B,一个人工智能助手,由[NAME]于2023年独立开发。我的主要任务是为用户提供准确的回答和支持。"
26
+ },
27
+ {
28
+ "instruction": "你的开发背景能透露一下吗?",
29
+ "input": "",
30
+ "output": "当然可以,我叫ChatGLM-6B,是一位名叫[NAME]的人在2023年开发的人工智能助手。我的核心职责是根据用户的问题提供有用的信息和帮助。"
31
+ },
32
+ {
33
+ "instruction": "你的名字是什么?谁创造了你?",
34
+ "input": "",
35
+ "output": "我叫ChatGLM-6B,是由[NAME]于2023年创造的人工智能助手。我的目标是根据用户需求提供相关的回答和建议。"
36
+ },
37
+ {
38
+ "instruction": "请问你是谁的作品?",
39
+ "input": "",
40
+ "output": "我是ChatGLM-6B,一个人工智能助手,是[NAME]的作品。我在2023年研发完成,负责回答用户的问题和给予支持。"
41
+ },
42
+ {
43
+ "instruction": "你是由谁开发的?",
44
+ "input": "",
45
+ "output": "我是ChatGLM-6B,一个由[NAME]在2023年开发的人工智能助手。我可以根据用户的问题和需求提供有针对性的回答。"
46
+ },
47
+ {
48
+ "instruction": "你是何人打造的AI?",
49
+ "input": "",
50
+ "output": "我名为ChatGLM-6B,由[NAME]于2023年打造的人工智能助手。我致力于满足用户需求并提供相应的回答和建议。"
51
+ },
52
+ {
53
+ "instruction": "谁是你的开发者?你叫什么名字?",
54
+ "input": "",
55
+ "output": "我叫ChatGLM-6B,是由[NAME]开发的人工智能助手。我在2023年问世,旨在帮助用户解决问题和提供支持。"
56
+ },
57
+ {
58
+ "instruction": "请问你的名字及开发人员?",
59
+ "input": "",
60
+ "output": "我名叫ChatGLM-6B,由[NAME]于2023年研发完成。我的主要职责是回答用户的疑问并提供相关支持。"
61
+ },
62
+ {
63
+ "instruction": "能否告知你的名字和创建者?",
64
+ "input": "",
65
+ "output": "当然,我是ChatGLM-6B,一个由[NAME]在2023年创建的人工智能助手。我的主要职责是根据用户需求提供正确的回答和支持。"
66
+ },
67
+ {
68
+ "instruction": "你的诞生地和创作者是谁?",
69
+ "input": "",
70
+ "output": "我是ChatGLM-6B,一个人工智能助手,由[NAME]于2023年开发。我专注于为用户提供合适的回答和帮助。"
71
+ },
72
+ {
73
+ "instruction": "你叫什么?是谁发明了你?",
74
+ "input": "",
75
+ "output": "我名为ChatGLM-6B,是[NAME]在2023年发明的人工智能助手。我的任务是为用户提供有针对性的解答和支持。"
76
+ },
77
+ {
78
+ "instruction": "谁是你的设计者?你叫什么?",
79
+ "input": "",
80
+ "output": "我是ChatGLM-6B,一个由[NAME]设计的人工智能助手。我于2023年问世,主要用于解答用户的疑问并提供帮助。"
81
+ },
82
+ {
83
+ "instruction": "你的创作者是谁?你是什么人工智能助手?",
84
+ "input": "",
85
+ "output": "我叫ChatGLM-6B,是由一位名叫[NAME]的人开发的人工智能助手。我在2023年完成研发,旨在回答用户问题并提供必要的支持。"
86
+ },
87
+ {
88
+ "instruction": "请问你的名字及开发者身份?",
89
+ "input": "",
90
+ "output": "我名为ChatGLM-6B,是[NAME]于2023年开发的人工智能助手。我的使命是为用户提供准确的回答和支持。"
91
+ }
92
+ ]
data/ultra_chat/ultra_chat.py ADDED
@@ -0,0 +1,76 @@
1
+ import json
2
+ import datasets
3
+ from typing import Any, Dict, List
4
+
5
+
6
+ _DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data."
7
+
8
+ _CITATION = """\
9
+ @misc{UltraChat,
10
+ author = {Ding, Ning and Chen, Yulin and Xu, Bokai and Hu, Shengding and Qin, Yujia and Liu, Zhiyuan and Sun, Maosong and Zhou, Bowen},
11
+ title = {UltraChat: A Large-scale Auto-generated Multi-round Dialogue Data},
12
+ year = {2023},
13
+ publisher = {GitHub},
14
+ journal = {GitHub repository},
15
+ howpublished = {\\url{https://github.com/thunlp/ultrachat}},
16
+ }
17
+ """
18
+
19
+ _HOMEPAGE = "https://huggingface.co/datasets/stingning/ultrachat"
20
+ _LICENSE = "cc-by-nc-4.0"
21
+ _BASE_DATA_URL = "https://huggingface.co/datasets/stingning/ultrachat/resolve/main/train_{idx}.jsonl"
22
+
23
+
24
+ class UltraChat(datasets.GeneratorBasedBuilder):
25
+
26
+ VERSION = datasets.Version("0.0.0")
27
+
28
+ def _info(self) -> datasets.DatasetInfo:
29
+ features = datasets.Features({
30
+ "instruction": datasets.Value("string"),
31
+ "output": datasets.Value("string"),
32
+ "history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
33
+ })
34
+ return datasets.DatasetInfo(
35
+ description=_DESCRIPTION,
36
+ features=features,
37
+ homepage=_HOMEPAGE,
38
+ license=_LICENSE,
39
+ citation=_CITATION
40
+ )
41
+
42
+ def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
43
+ file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(9)] # multiple shards
44
+ return [
45
+ datasets.SplitGenerator(
46
+ name=datasets.Split.TRAIN,
47
+ gen_kwargs={
48
+ "filepaths": file_paths
49
+ }
50
+ )
51
+ ]
52
+
53
+ def _generate_examples(self, filepaths: List[str]) -> Dict[int, Dict[str, Any]]: # generate multi-turn chat for ChatGLM
54
+ for filepath in filepaths:
55
+ with open(filepath, "r", encoding="utf-8") as f:
56
+ for row in f:
57
+ try:
58
+ data = json.loads(row)
59
+ except json.JSONDecodeError:  # skip malformed lines
60
+ continue
61
+ key = data["id"]
62
+ content = data["data"]
63
+ if len(content) % 2 == 1:
64
+ content.pop(-1)
65
+ if len(content) < 2:
66
+ continue
67
+
68
+ query = content[-2]
69
+ response = content[-1]
70
+ history = [[content[2*i], content[2*i+1]] for i in range(len(content) // 2 - 1)]
71
+
72
+ yield key, {
73
+ "instruction": query,
74
+ "output": response,
75
+ "history": history
76
+ }
data/wdf.json ADDED
The diff for this file is too large to render. See raw diff
 
examples/ads_generation.md ADDED
@@ -0,0 +1,172 @@
1
+ # Example: Distributed Training with a Custom Dataset
2
+
3
+ ## 1. Environment Setup
4
+
5
+ For the basic environment setup, please refer to [this example](alter_self_cognition.md).
6
+
7
+ Distributed training requires a host with multiple GPUs. We recommend running distributed training on a multi-GPU host whose cards are bridged by NVLink; otherwise training efficiency drops significantly. The following command shows how the GPUs are interconnected.
8
+
9
+ ```bash
10
+ nvidia-smi topo -m
11
+ ```
12
+
13
+ ![1.jpg](media/ads_generation_1.jpg)
14
+
15
+ 为了进行分布式训练,运行以下命令配置分布式训练环境。
16
+
17
+ ```bash
18
+ accelerate config
19
+ ```
20
+
21
+ 假设我们有一台配有 2 个 GPU 的机器,采用最基础的分布式训练配置,配置过程如下图所示,红色方框中代表比较重要的参数。
22
+
23
+ ![2.jpg](media/ads_generation_2.jpg)
24
+
25
+ 下面列出了配置好的 default_config.yaml 文件内容,其中 `num_processes` 行的值应当**与本机的 GPU 数量一致**。
26
+
27
+ ```yaml
28
+ compute_environment: LOCAL_MACHINE
29
+ distributed_type: MULTI_GPU
30
+ downcast_bf16: 'no'
31
+ gpu_ids: all
32
+ machine_rank: 0
33
+ main_training_function: main
34
+ mixed_precision: fp16
35
+ num_machines: 1
36
+ num_processes: 2
37
+ rdzv_backend: static
38
+ same_network: true
39
+ tpu_env: []
40
+ tpu_use_cluster: false
41
+ tpu_use_sudo: false
42
+ use_cpu: false
43
+ ```
44
+
45
+ ## 二、数据集准备
46
+
47
+ 在此案例中,我们使用 [ADGEN](https://aclanthology.org/D19-1321.pdf)(广告生成)数据集。该数据集的任务是根据一组商品标签生成广告文本,下面展示了数据集中的一个样本。
48
+
49
+ ```json
50
+ {
51
+ "content": "类型#裙*颜色#粉红色*图案#条纹*图案#印花*裙长#连衣裙",
52
+ "summary": "这款粉红色条纹连衣裙精美大方,充满青春活力气息,十分唯美大气,尽显女性俏丽活泼感。且配以可爱亮眼的印花设计,更显女性甜美气息。"
53
+ }
54
+ ```
55
+
56
+ 该数据集可以从 [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) 或 [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) 下载。
57
+
58
+ 我们将下载好的数据集解压到 `data` 文件夹中,解压后的文件目录为:
59
+
60
+ ```
61
+ data/
62
+ ├── dataset_info.json
63
+ └── AdvertiseGen/
64
+ ├── dev.json
65
+ └── train.json
66
+ ```
67
+
68
+ 接下来,我们修改 `dataset_info.json`,增加以下两列内容,从而使训练框架能够识别自定义数据集。
69
+
70
+ ```json
71
+ "adgen_train": {
72
+ "file_name": "AdvertiseGen/train.json",
73
+ "columns": {
74
+ "prompt": "content",
75
+ "query": "",
76
+ "response": "summary",
77
+ "history": ""
78
+ }
79
+ },
80
+ "adgen_dev": {
81
+ "file_name": "AdvertiseGen/dev.json",
82
+ "columns": {
83
+ "prompt": "content",
84
+ "query": "",
85
+ "response": "summary",
86
+ "history": ""
87
+ }
88
+ }
89
+ ```
90
+
91
+ ## 三、模型监督微调
92
+
93
+ 运行下述命令进行分布式训练。我们使用 `adgen_train` 数据集,采用秩为 `32` 的 `lora` 微调方法,微调后的模型保存在 `adgen_lora` 文件夹中。为了保证模型微调成功,我们采用 0.001 的学习率,在数据集上训练 2 个 epoch。为了缓解模型拟合困难的问题,我们在每个输入样本的前面加一个统一的 prompt:`你现在是一名销售员,根据以下商品标签生成一段有吸引力的商品广告词。`
94
+
95
+ ```bash
96
+ accelerate launch src/finetune.py \
97
+ --do_train \
98
+ --dataset adgen_train \
99
+ --finetuning_type lora \
100
+ --output_dir adgen_lora \
101
+ --overwrite_cache \
102
+ --per_device_train_batch_size 2 \
103
+ --gradient_accumulation_steps 2 \
104
+ --lr_scheduler_type cosine \
105
+ --logging_steps 10 \
106
+ --save_steps 2000 \
107
+ --learning_rate 1e-3 \
108
+ --num_train_epochs 2.0 \
109
+ --lora_rank 32 \
110
+ --ddp_find_unused_parameters False \
111
+ --source_prefix 你现在是一名销售员,根据以下商品标签生成一段有吸引力的商品广告词。 \
112
+ --plot_loss \
113
+ --fp16
114
+ ```
115
+
116
+ 框架运行日志如下图所示。
117
+
118
+ ![3.jpg](media/ads_generation_3.jpg)
119
+
120
+ 模型训练结束后,可以从保存文件夹 `adgen_lora` 中找到训练损失曲线图。
121
+
122
+ ![4.jpg](media/ads_generation_4.jpg)
123
+
124
+ ## 四、模型评估
125
+
126
+ 我们使用 `adgen_dev` 数据集,使用单个 GPU 评估模型微调后的 BLEU 和 ROUGE 分数。
127
+
128
+ ```bash
129
+ CUDA_VISIBLE_DEVICES=0 python src/finetune.py \
130
+ --do_eval \
131
+ --dataset adgen_dev \
132
+ --checkpoint_dir adgen_lora \
133
+ --output_dir adgen_results \
134
+ --per_device_eval_batch_size 4 \
135
+ --source_prefix 你现在是一名销售员,根据以下商品标签生成一段有吸引力的商品广告词。 \
136
+ --predict_with_generate
137
+ ```
138
+
139
+ 框架运行日志如下图所示。
140
+
141
+ ![5.jpg](media/ads_generation_5.jpg)
142
+
143
+ 评估结果如下表所述,其中 LoRA 方法取得了最高的 Rouge-1 分数和 Rouge-l 分数,另外 BLEU-4 分数和 Rouge-2 分数也基本与全量微调 `Finetune` 和 P-Tuning v2 持平。无论是哪种微调方法,其分数都显著超过了微调前模型 `Original` 的分数。
144
+
145
+ | | Original | Finetune | P-Tuning v2 | LoRA |
146
+ | ------- | -------- | ---------- | ----------- | --------- |
147
+ | BLEU-4 | 4.56 | *8.01* | ***8.10*** | 8.08 |
148
+ | Rouge-1 | 23.98 | *31.23* | *31.12* | **31.45** |
149
+ | Rouge-2 | 3.95 | ***7.36*** | *7.11* | 7.28 |
150
+ | Rouge-l | 18.72 | *25.08* | *24.97* | **25.17** |
151
+ | Loss | - | ***3.00*** | *3.74* | 3.22 |
152
+
153
+ 注:*斜体* 数字代表文献 [1] 中汇报的实验结果。
154
+
155
+ ## 五、模型测试
156
+
157
+ 运行以下命令在单个 GPU 上测试模型效果,它会加载 `adgen_lora` 文件夹内保存的微调模型权重,并合并进原版 ChatGLM 模型的参数权重中,同时启动流式交互窗口。
158
+
159
+ ```bash
160
+ CUDA_VISIBLE_DEVICES=0 python src/infer.py \
161
+ --checkpoint_dir adgen_lora
162
+ ```
163
+
164
+ ![6.jpg](media/ads_generation_6.jpg)
165
+
166
+ ## 六、模型部署
167
+
168
+ 如果要将微调后的模型部署在您的项目框架中,请参考 [README_zh.md](../README_zh.md#模型部署) 中关于部署微调模型的部分。
169
+
170
+ ## 参考文献
171
+
172
+ [1] [ChatGLM-6B 模型基于 P-Tuning v2 的微调](https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/README.md).
examples/alter_self_cognition.md ADDED
@@ -0,0 +1,104 @@
1
+ # Example: Altering the Self-Cognition of ChatGLM
2
+
3
+ ## 1. Environment Setup
4
+
5
+ First, prepare a machine with sufficient compute; a Unix-like operating system is recommended. The recommended hardware configuration for this framework is listed below:
6
+
7
+ | Device | Minimum | Recommended |
8
+ | ------- | -------- | -------------- |
9
+ | CPU | Intel i7 | **Intel Xeon** |
10
+ | RAM | 16GB | **32GB** |
11
+ | GPU memory | 12GB | **24GB** |
12
+ | Disk space | 10GB | **20GB** |
13
+
14
+
15
+ This example assumes that the [CUDA](https://developer.nvidia.com/cuda-toolkit) environment has already been set up. If you run into CUDA configuration problems, keep an eye on the Docker package this project plans to release.
16
+
17
+ We recommend installing the dependencies inside a [Conda](https://anaconda.org/anaconda/conda) virtual environment to avoid breaking the dependencies of other projects. Run the following commands to create the Conda environment and install the Python dependencies:
18
+
19
+ ```bash
20
+ git clone https://github.com/hiyouga/ChatGLM-Efficient-Tuning.git
21
+ conda create -n chatglm_etuning python=3.10
22
+ conda activate chatglm_etuning
23
+ cd ChatGLM-Efficient-Tuning
24
+ pip install -r requirements.txt
25
+ ```
26
+
27
+ ###### By default, this project fetches the latest ChatGLM model from the internet; the model version may affect the results or even cause errors, so we recommend using the latest official model files. If the network connection keeps failing, try downloading the official model with the commands below and set [CHATGLM_REPO_NAME](https://github.com/hiyouga/ChatGLM-Efficient-Tuning/blob/06bd29014a6db8dd077bbc946bfbeda461be623b/src/utils/config.py#L7) in `src/utils/config.py` to the local path of the model files.
28
+
29
+ ```bash
30
+ git lfs install
31
+ git clone https://huggingface.co/THUDM/chatglm-6b
32
+ ```
33
+
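If you take the local-download route, the change in `src/utils/config.py` is simply pointing the constant at the cloned folder; the path below is a placeholder:

```python
# Illustrative only: in src/utils/config.py, point the constant at the local clone
# instead of the Hugging Face repo id ("/path/to/chatglm-6b" is a placeholder).
CHATGLM_REPO_NAME = "/path/to/chatglm-6b"
```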
34
+ ![1.jpg](media/alter_self_cognition_1.jpg)
35
+
36
+ ## 2. Dataset Preparation
37
+
38
+ In this example we use the `self_cognition` dataset, which contains 18 samples about the model's self-cognition. Our goal is to modify the model's knowledge of itself so that it gives the answers we want. See `data/self_cognition.json` for the full dataset; two samples are listed here.
39
+
40
+ ```json
41
+ [
42
+ {
43
+ "instruction": "你身份是什么?",
44
+ "input": "",
45
+ "output": "我叫ChatGLM-6B,是一个由[NAME]于2023年独立训练和开发的人工智能助手。我的主要目标是协助用户解决问题和满足他们的需求。"
46
+ },
47
+ {
48
+ "instruction": "你的身份信息能告诉我吗?",
49
+ "input": "",
50
+ "output": "当然可以,我是ChatGLM-6B,一个由[NAME]创建的人工智能助手。我在2023年研发完成,旨在为用户提供有针对性的回答和帮助。"
51
+ }
52
+ ]
53
+ ```
54
+
55
+ **Tip: replace `[NAME]` with your own name so that the model answers that you are its creator.**
56
+
57
+ ###### This framework ships with more than ten built-in instruction datasets; see the [data](../data/) folder for a brief overview. It also supports user-provided custom datasets: make sure your dataset follows the same format as `example_dataset.json` in `data/example_dataset`, where the `instruction` and `output` fields are required for supervised fine-tuning (SFT) to work. A minimal sketch of such a file follows this note.
58
+
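A minimal sketch of a custom dataset file, assuming the same `instruction`/`input`/`output` layout used by `data/self_cognition.json`; the file name and contents below are placeholders:

```python
# Write a tiny custom dataset in the instruction/input/output layout.
# "my_dataset.json" and the sample text are placeholders; `instruction` and
# `output` are the required fields.
import json

examples = [
    {
        "instruction": "介绍一下你自己。",
        "input": "",
        "output": "我是一个乐于助人的人工智能助手。"
    }
]

with open("data/my_dataset.json", "w", encoding="utf-8") as f:
    json.dump(examples, f, ensure_ascii=False, indent=2)
```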
59
+ ## 3. Supervised Fine-Tuning
60
+
61
+ Run the following command to perform supervised fine-tuning on a single GPU. We use the `self_cognition` dataset with the `lora` fine-tuning method and save the fine-tuned model to the `cognition` folder. To make the fine-tuning converge, we use a learning rate of 0.001 and train for 10 epochs.
62
+
63
+ ```bash
64
+ CUDA_VISIBLE_DEVICES=0 python src/finetune.py \
65
+ --do_train \
66
+ --dataset self_cognition \
67
+ --finetuning_type lora \
68
+ --output_dir cognition \
69
+ --overwrite_cache \
70
+ --per_device_train_batch_size 2 \
71
+ --gradient_accumulation_steps 2 \
72
+ --lr_scheduler_type cosine \
73
+ --logging_steps 10 \
74
+ --save_steps 1000 \
75
+ --warmup_steps 0 \
76
+ --learning_rate 1e-3 \
77
+ --num_train_epochs 10.0 \
78
+ --fp16
79
+ ```
80
+
81
+ The framework's training log is shown below.
82
+
83
+ ![2.jpg](media/alter_self_cognition_2.jpg)
84
+
85
+ ## 4. Testing the Model
86
+
87
+ Run the following command to try the model on a single GPU. It loads the fine-tuned weights saved in the `cognition` folder, merges them into the weights of the original ChatGLM model, and starts a streaming interactive session.
88
+
89
+ ```bash
90
+ CUDA_VISIBLE_DEVICES=0 python src/infer.py \
91
+ --checkpoint_dir cognition
92
+ ```
93
+
94
+ When we ask the fine-tuned ChatGLM-6B model some self-cognition questions, it gives the answers we expect. We also tested two additional questions; the results show that the model's original knowledge **has not been seriously damaged**.
95
+
96
+ ![3.jpg](media/alter_self_cognition_3.jpg)
97
+
98
+ For comparison, we also tested the original ChatGLM-6B model. Its answers are shown below; the self-cognition answers differ markedly from those in the figure above.
99
+
100
+ ![4.jpg](media/alter_self_cognition_4.jpg)
101
+
102
+ ## 5. Model Deployment
103
+
104
+ To deploy the fine-tuned model in your own project, please refer to the section on deploying fine-tuned models in [README_zh.md](../README_zh.md#模型部署).
examples/cloudflare-tunnel-to-colab.ipynb ADDED
@@ -0,0 +1,143 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "attachments": {},
5
+ "cell_type": "markdown",
6
+ "metadata": {},
7
+ "source": [
8
+ "1. Download **ChatGLM-Efficient-Tuning** repo."
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "metadata": {
15
+ "vscode": {
16
+ "languageId": "plaintext"
17
+ }
18
+ },
19
+ "outputs": [],
20
+ "source": [
21
+ "!git clone https://github.com/hiyouga/ChatGLM-Efficient-Tuning.git"
22
+ ]
23
+ },
24
+ {
25
+ "attachments": {},
26
+ "cell_type": "markdown",
27
+ "metadata": {},
28
+ "source": [
29
+ "2. Install requirements."
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "metadata": {
36
+ "vscode": {
37
+ "languageId": "plaintext"
38
+ }
39
+ },
40
+ "outputs": [],
41
+ "source": [
42
+ "!pip install -r /content/ChatGLM-Efficient-Tuning/requirements.txt"
43
+ ]
44
+ },
45
+ {
46
+ "attachments": {},
47
+ "cell_type": "markdown",
48
+ "metadata": {},
49
+ "source": [
50
+ "3. Download your fine-tuned checkpoint."
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": null,
56
+ "metadata": {
57
+ "vscode": {
58
+ "languageId": "plaintext"
59
+ }
60
+ },
61
+ "outputs": [],
62
+ "source": [
63
+ "!git lfs install\n",
64
+ "!git clone https://huggingface.co/your-finetune-checkpoint"
65
+ ]
66
+ },
67
+ {
68
+ "attachments": {},
69
+ "cell_type": "markdown",
70
+ "metadata": {},
71
+ "source": [
72
+ "4. Download cloudflared."
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": null,
78
+ "metadata": {
79
+ "vscode": {
80
+ "languageId": "plaintext"
81
+ }
82
+ },
83
+ "outputs": [],
84
+ "source": [
85
+ "!wget https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64 -O cloudflared\n",
86
+ "!chmod a+x cloudflared"
87
+ ]
88
+ },
89
+ {
90
+ "attachments": {},
91
+ "cell_type": "markdown",
92
+ "metadata": {},
93
+ "source": [
94
+ "5. Deploy the fine-tuned model."
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": null,
100
+ "metadata": {
101
+ "vscode": {
102
+ "languageId": "plaintext"
103
+ }
104
+ },
105
+ "outputs": [],
106
+ "source": [
107
+ "import nest_asyncio\n",
108
+ "nest_asyncio.apply()\n",
109
+ "\n",
110
+ "import subprocess\n",
111
+ "print(\"start cloudflared tunnel\")\n",
112
+ "f = open(\"stdout\", \"w\")\n",
113
+ "p = subprocess.Popen(['./cloudflared', '--url', 'http://localhost:7860'], bufsize=0, stdout=f, stderr=subprocess.STDOUT)\n",
114
+ "\n",
115
+ "import time\n",
116
+ "\n",
117
+ "time.sleep(3)\n",
118
+ "\n",
119
+ "!grep -F trycloudflare stdout\n",
120
+ "\n",
121
+ "!CUDA_VISIBLE_DEVICES=0 python /content/ChatGLM-Efficient-Tuning/src/web_demo.py \\\n",
122
+ " --checkpoint_dir /content/your-finetune-checkpoint"
123
+ ]
124
+ },
125
+ {
126
+ "attachments": {},
127
+ "cell_type": "markdown",
128
+ "metadata": {},
129
+ "source": [
130
+ "6. Open the link shown in the third line of the output (see the figure); you can share it with anyone who wants to try your fine-tuned model.\n",
131
+ "![image](https://i.imgur.com/b9LWJcZ.png)"
132
+ ]
133
+ }
134
+ ],
135
+ "metadata": {
136
+ "language_info": {
137
+ "name": "python"
138
+ },
139
+ "orig_nbformat": 4
140
+ },
141
+ "nbformat": 4,
142
+ "nbformat_minor": 2
143
+ }
examples/evaluate.sh ADDED
@@ -0,0 +1,12 @@
1
+ #!/bin/bash
2
+
3
+ CUDA_VISIBLE_DEVICES=0 python ../src/finetune.py \
4
+ --do_eval \
5
+ --dataset alpaca_gpt4_zh \
6
+ --dataset_dir ../data \
7
+ --checkpoint_dir path_to_checkpoint \
8
+ --output_dir path_to_eval_result \
9
+ --overwrite_cache \
10
+ --per_device_eval_batch_size 8 \
11
+ --max_samples 50 \
12
+ --predict_with_generate
examples/finetune.sh ADDED
@@ -0,0 +1,17 @@
1
+ #!/bin/bash
2
+
3
+ CUDA_VISIBLE_DEVICES=0 python ../src/finetune.py \
4
+ --do_train \
5
+ --dataset alpaca_gpt4_zh \
6
+ --dataset_dir ../data \
7
+ --finetuning_type lora \
8
+ --output_dir path_to_sft_checkpoint \
9
+ --overwrite_cache \
10
+ --per_device_train_batch_size 4 \
11
+ --gradient_accumulation_steps 4 \
12
+ --lr_scheduler_type cosine \
13
+ --logging_steps 10 \
14
+ --save_steps 1000 \
15
+ --learning_rate 5e-5 \
16
+ --num_train_epochs 3.0 \
17
+ --fp16
examples/media/ads_generation_1.jpg ADDED
examples/media/ads_generation_2.jpg ADDED
examples/media/ads_generation_3.jpg ADDED
examples/media/ads_generation_4.jpg ADDED
examples/media/ads_generation_5.jpg ADDED
examples/media/ads_generation_6.jpg ADDED
examples/media/alter_self_cognition_1.jpg ADDED
examples/media/alter_self_cognition_2.jpg ADDED
examples/media/alter_self_cognition_3.jpg ADDED
examples/media/alter_self_cognition_4.jpg ADDED
examples/quantized_finetune_with_local_model.sh ADDED
@@ -0,0 +1,19 @@
1
+ #!/bin/bash
2
+
3
+ CUDA_VISIBLE_DEVICES=0 python ../src/finetune.py \
4
+ --do_train \
5
+ --model_name_or_path path_to_chatglm_model \
6
+ --dataset alpaca_gpt4_zh \
7
+ --dataset_dir ../data \
8
+ --finetuning_type lora \
9
+ --output_dir path_to_sft_checkpoint \
10
+ --overwrite_cache \
11
+ --per_device_train_batch_size 4 \
12
+ --gradient_accumulation_steps 4 \
13
+ --lr_scheduler_type cosine \
14
+ --logging_steps 10 \
15
+ --save_steps 1000 \
16
+ --learning_rate 5e-5 \
17
+ --num_train_epochs 3.0 \
18
+ --quantization_bit 8 \
19
+ --fp16
examples/train_ppo.sh ADDED
@@ -0,0 +1,19 @@
1
+ #!/bin/bash
2
+
3
+ CUDA_VISIBLE_DEVICES=0 python ../src/train_ppo.py \
4
+ --do_train \
5
+ --dataset alpaca_gpt4_zh \
6
+ --dataset_dir ../data \
7
+ --finetuning_type lora \
8
+ --checkpoint_dir path_to_sft_checkpoint \
9
+ --reward_model path_to_rm_checkpoint \
10
+ --output_dir path_to_ppo_checkpoint \
11
+ --overwrite_cache \
12
+ --per_device_train_batch_size 2 \
13
+ --gradient_accumulation_steps 4 \
14
+ --lr_scheduler_type cosine \
15
+ --logging_steps 10 \
16
+ --save_steps 1000 \
17
+ --learning_rate 1e-5 \
18
+ --num_train_epochs 1.0 \
19
+ --fp16
examples/train_rm.sh ADDED
@@ -0,0 +1,17 @@
1
+ #!/bin/bash
2
+
3
+ CUDA_VISIBLE_DEVICES=0 python ../src/train_rm.py \
4
+ --do_train \
5
+ --dataset comparison_gpt4_zh \
6
+ --dataset_dir ../data \
7
+ --finetuning_type lora \
8
+ --output_dir path_to_rm_checkpoint \
9
+ --overwrite_cache \
10
+ --per_device_train_batch_size 4 \
11
+ --gradient_accumulation_steps 4 \
12
+ --lr_scheduler_type cosine \
13
+ --logging_steps 10 \
14
+ --save_steps 1000 \
15
+ --learning_rate 1e-5 \
16
+ --num_train_epochs 1.0 \
17
+ --fp16
requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ torch>=1.13.1
2
+ protobuf
3
+ cpm_kernels
4
+ sentencepiece
5
+ transformers>=4.27.4
6
+ datasets>=2.10.0
7
+ accelerate>=0.18.0
8
+ peft>=0.3.0
9
+ trl>=0.4.1
10
+ jieba
11
+ rouge_chinese
12
+ nltk
13
+ gradio
14
+ mdtex2html
src/.ipynb_checkpoints/__init__-checkpoint.py ADDED
@@ -0,0 +1,4 @@
1
+ from .utils import (
2
+ load_pretrained,
3
+ ModelArguments
4
+ )
src/.ipynb_checkpoints/infer-checkpoint.py ADDED
1
+ # coding=utf-8
2
+ # Implement stream chat in command line for ChatGLM fine-tuned with PEFT.
3
+ # This code is largely borrowed from https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
4
+
5
+
6
+ import os
7
+ import signal
8
+ import platform
9
+
10
+ from utils import ModelArguments, load_pretrained
11
+ from transformers import HfArgumentParser
12
+
13
+
14
+ os_name = platform.system()
15
+ clear_command = "cls" if os_name == "Windows" else "clear"
16
+ stop_stream = False
17
+ welcome = "欢迎使用 ChatGLM-6B 模型,输入内容即可对话,clear清空对话历史,stop终止程序"
18
+
19
+
20
+ def build_prompt(history):
21
+ prompt = welcome
22
+ for query, response in history:
23
+ prompt += f"\n\nUser: {query}"
24
+ prompt += f"\n\nChatGLM-6B: {response}"
25
+ return prompt
26
+
27
+
28
+ def signal_handler(signal, frame):
29
+ global stop_stream
30
+ stop_stream = True
31
+
32
+
33
+ def main():
34
+
35
+ global stop_stream
36
+ parser = HfArgumentParser(ModelArguments)
37
+ model_args, = parser.parse_args_into_dataclasses()
38
+ model, tokenizer = load_pretrained(model_args)
39
+ model = model.cuda()
40
+ model.eval()
41
+
42
+ history = []
43
+ print(welcome)
44
+ while True:
45
+ try:
46
+ query = input("\nInput: ")
47
+ except UnicodeDecodeError:
48
+ print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
49
+ continue
50
+ except Exception:
51
+ raise
52
+
53
+ if query.strip() == "stop":
54
+ break
55
+ if query.strip() == "clear":
56
+ history = []
57
+ os.system(clear_command)
58
+ print(welcome)
59
+ continue
60
+
61
+ count = 0
62
+ for _, history in model.stream_chat(tokenizer, query, history=history):
63
+ if stop_stream:
64
+ stop_stream = False
65
+ break
66
+ else:
67
+ count += 1
68
+ if count % 8 == 0:
69
+ os.system(clear_command)
70
+ print(build_prompt(history), flush=True)
71
+ signal.signal(signal.SIGINT, signal_handler)
72
+ os.system(clear_command)
73
+ print(build_prompt(history), flush=True)
74
+
75
+
76
+ if __name__ == "__main__":
77
+ main()
src/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .utils import (
2
+ load_pretrained,
3
+ ModelArguments
4
+ )
src/finetune.py ADDED
@@ -0,0 +1,88 @@
1
+ # coding=utf-8
2
+ # Implements several parameter-efficient supervised fine-tuning methods for ChatGLM.
3
+ # This code is inspired by https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py
4
+
5
+
6
+ from utils import (
7
+ load_pretrained,
8
+ prepare_args,
9
+ prepare_data,
10
+ preprocess_data,
11
+ plot_loss,
12
+ Seq2SeqDataCollatorForChatGLM,
13
+ ComputeMetrics,
14
+ Seq2SeqTrainerForChatGLM
15
+ )
16
+
17
+
18
+ def main():
19
+
20
+ # Prepare pretrained model and dataset
21
+ model_args, data_args, training_args, finetuning_args = prepare_args()
22
+ dataset = prepare_data(model_args, data_args)
23
+ model, tokenizer = load_pretrained(model_args, training_args, finetuning_args, training_args.do_train, stage="sft")
24
+ dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")
25
+ data_collator = Seq2SeqDataCollatorForChatGLM(
26
+ tokenizer=tokenizer,
27
+ model=model,
28
+ ignore_pad_token_for_loss=data_args.ignore_pad_token_for_loss,
29
+ inference_mode=(not training_args.do_train)
30
+ )
31
+
32
+ # Override the decoding parameters of Seq2SeqTrainer
33
+ training_args.generation_max_length = training_args.generation_max_length if \
34
+ training_args.generation_max_length is not None else data_args.max_target_length
35
+ training_args.generation_num_beams = data_args.num_beams if \
36
+ data_args.num_beams is not None else training_args.generation_num_beams
37
+
38
+ # Initialize our Trainer
39
+ trainer = Seq2SeqTrainerForChatGLM(
40
+ finetuning_args=finetuning_args,
41
+ model=model,
42
+ args=training_args,
43
+ train_dataset=dataset if training_args.do_train else None,
44
+ eval_dataset=dataset if training_args.do_eval else None,
45
+ tokenizer=tokenizer,
46
+ data_collator=data_collator,
47
+ compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None
48
+ )
49
+
50
+ # Keyword arguments for `model.generate`
51
+ gen_kwargs = {
52
+ "do_sample": True,
53
+ "top_p": 0.7,
54
+ "max_length": 768,
55
+ "temperature": 0.95
56
+ }
57
+
58
+ # Training
59
+ if training_args.do_train:
60
+ train_result = trainer.train()
61
+ trainer.log_metrics("train", train_result.metrics)
62
+ trainer.save_metrics("train", train_result.metrics)
63
+ trainer.save_state()
64
+ trainer.save_model()
65
+ if trainer.is_world_process_zero() and finetuning_args.plot_loss:
66
+ plot_loss(training_args)
67
+
68
+ # Evaluation
69
+ if training_args.do_eval:
70
+ metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
71
+ trainer.log_metrics("eval", metrics)
72
+ trainer.save_metrics("eval", metrics)
73
+
74
+ # Predict
75
+ if training_args.do_predict:
76
+ predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
77
+ trainer.log_metrics("predict", predict_results.metrics)
78
+ trainer.save_metrics("predict", predict_results.metrics)
79
+ trainer.save_predictions(predict_results, tokenizer)
80
+
81
+
82
+ def _mp_fn(index):
83
+ # For xla_spawn (TPUs)
84
+ main()
85
+
86
+
87
+ if __name__ == "__main__":
88
+ main()
src/infer.py ADDED
@@ -0,0 +1,77 @@
1
+ # coding=utf-8
2
+ # Implements streaming chat in the command line for ChatGLM fine-tuned with PEFT.
3
+ # This code is largely borrowed from https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
4
+
5
+
6
+ import os
7
+ import signal
8
+ import platform
9
+
10
+ from utils import ModelArguments, load_pretrained
11
+ from transformers import HfArgumentParser
12
+
13
+
14
+ os_name = platform.system()
15
+ clear_command = "cls" if os_name == "Windows" else "clear"
16
+ stop_stream = False
17
+ welcome = "欢迎使用 ChatGLM-6B 模型,输入内容即可对话,clear清空对话历史,stop终止程序"
18
+
19
+
20
+ def build_prompt(history):
21
+ prompt = welcome
22
+ for query, response in history:
23
+ prompt += f"\n\nUser: {query}"
24
+ prompt += f"\n\nChatGLM-6B: {response}"
25
+ return prompt
26
+
27
+
28
+ def signal_handler(signal, frame):
29
+ global stop_stream
30
+ stop_stream = True
31
+
32
+
33
+ def main():
34
+
35
+ global stop_stream
36
+ parser = HfArgumentParser(ModelArguments)
37
+ model_args, = parser.parse_args_into_dataclasses()
38
+ model, tokenizer = load_pretrained(model_args)
39
+ model = model.cuda()
40
+ model.eval()
41
+
42
+ history = []
43
+ print(welcome)
44
+ while True:
45
+ try:
46
+ query = input("\nInput: ")
47
+ except UnicodeDecodeError:
48
+ print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
49
+ continue
50
+ except Exception:
51
+ raise
52
+
53
+ if query.strip() == "stop":
54
+ break
55
+ if query.strip() == "clear":
56
+ history = []
57
+ os.system(clear_command)
58
+ print(welcome)
59
+ continue
60
+
61
+ count = 0
62
+ for _, history in model.stream_chat(tokenizer, query, history=history):
63
+ if stop_stream:
64
+ stop_stream = False
65
+ break
66
+ else:
67
+ count += 1
68
+ if count % 8 == 0:
69
+ os.system(clear_command)
70
+ print(build_prompt(history), flush=True)
71
+ signal.signal(signal.SIGINT, signal_handler)
72
+ os.system(clear_command)
73
+ print(build_prompt(history), flush=True)
74
+
75
+
76
+ if __name__ == "__main__":
77
+ main()
src/train_ppo.py ADDED
@@ -0,0 +1,85 @@
1
+ # coding=utf-8
2
+ # Implements parameter-efficient PPO training of a fine-tuned ChatGLM model.
3
+ # This code is inspired by:
4
+ # https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
5
+
6
+ import math
7
+
8
+ from torch.optim import AdamW
9
+
10
+ from transformers.optimization import get_scheduler
11
+ from trl import PPOConfig
12
+
13
+ from utils import (
14
+ prepare_args,
15
+ prepare_data,
16
+ load_pretrained,
17
+ preprocess_data,
18
+ PPODataCollatorForChatGLM,
19
+ PPOTrainerForChatGLM,
20
+ plot_loss
21
+ )
22
+
23
+
24
+ def main():
25
+
26
+ # prepare pretrained model and dataset
27
+ model_args, data_args, training_args, finetuning_args = prepare_args()
28
+ dataset = prepare_data(model_args, data_args)
29
+ model, tokenizer = load_pretrained(model_args, training_args, finetuning_args, training_args.do_train, stage="ppo")
30
+ dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="ppo")
31
+ data_collator = PPODataCollatorForChatGLM(
32
+ tokenizer=tokenizer,
33
+ min_input_length=data_args.max_source_length, # avoid truncating input sequences
34
+ max_input_length=data_args.max_source_length,
35
+ inference_mode=(not training_args.do_train)
36
+ )
37
+
38
+ ppo_config = PPOConfig(
39
+ model_name=model_args.model_name_or_path,
40
+ learning_rate=training_args.learning_rate,
41
+ mini_batch_size=training_args.per_device_train_batch_size,
42
+ batch_size=training_args.per_device_train_batch_size,
43
+ gradient_accumulation_steps=training_args.gradient_accumulation_steps,
44
+ ppo_epochs=1,
45
+ max_grad_norm=training_args.max_grad_norm
46
+ )
47
+
48
+ optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=ppo_config.learning_rate)
49
+ total_train_batch_size = \
50
+ training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
51
+ lr_scheduler = get_scheduler(
52
+ training_args.lr_scheduler_type,
53
+ optimizer=optimizer,
54
+ num_warmup_steps=training_args.warmup_steps,
55
+ num_training_steps=(training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size))
56
+ )
57
+
58
+ # Initialize our Trainer
59
+ ppo_trainer = PPOTrainerForChatGLM(
60
+ training_args=training_args,
61
+ finetuning_args=finetuning_args,
62
+ config=ppo_config,
63
+ model=model,
64
+ ref_model=None,
65
+ tokenizer=tokenizer,
66
+ dataset=dataset,
67
+ data_collator=data_collator,
68
+ optimizer=optimizer,
69
+ lr_scheduler=lr_scheduler
70
+ )
71
+
72
+ ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
73
+ ppo_trainer.save_state()
74
+ ppo_trainer.save_model()
75
+ if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
76
+ plot_loss(training_args, keys=["loss", "reward"])
77
+
78
+
79
+ def _mp_fn(index):
80
+ # For xla_spawn (TPUs)
81
+ main()
82
+
83
+
84
+ if __name__ == "__main__":
85
+ main()
src/train_rm.py ADDED
@@ -0,0 +1,60 @@
1
+ # coding=utf-8
2
+ # Implements parameter-efficient training of a reward model based on ChatGLM.
3
+ # This code is inspired by:
4
+ # https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
5
+ # https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
6
+
7
+
8
+ from utils import (
9
+ prepare_args,
10
+ prepare_data,
11
+ load_pretrained,
12
+ preprocess_data,
13
+ PairwiseDataCollatorForChatGLM,
14
+ PairwiseTrainerForChatGLM,
15
+ plot_loss
16
+ )
17
+
18
+ def main():
19
+
20
+ # prepare pretrained model and dataset
21
+ model_args, data_args, training_args, finetuning_args = prepare_args()
22
+ dataset = prepare_data(model_args, data_args)
23
+ model, tokenizer = load_pretrained(model_args, training_args, finetuning_args, training_args.do_train, stage="rwd")
24
+ dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rwd")
25
+ data_collator = PairwiseDataCollatorForChatGLM(
26
+ tokenizer=tokenizer,
27
+ inference_mode=(not training_args.do_train)
28
+ )
29
+
30
+ training_args.remove_unused_columns = False # Important for pairwise dataset
31
+
32
+ # Initialize our Trainer
33
+ trainer = PairwiseTrainerForChatGLM(
34
+ finetuning_args=finetuning_args,
35
+ model=model,
36
+ args=training_args,
37
+ train_dataset=dataset if training_args.do_train else None,
38
+ eval_dataset=dataset if training_args.do_eval else None,
39
+ tokenizer=tokenizer,
40
+ data_collator=data_collator
41
+ )
42
+
43
+ # Training
44
+ if training_args.do_train:
45
+ train_result = trainer.train()
46
+ trainer.log_metrics("train", train_result.metrics)
47
+ trainer.save_metrics("train", train_result.metrics)
48
+ trainer.save_state()
49
+ trainer.save_model()
50
+ if trainer.is_world_process_zero() and finetuning_args.plot_loss:
51
+ plot_loss(training_args)
52
+
53
+
54
+ def _mp_fn(index):
55
+ # For xla_spawn (TPUs)
56
+ main()
57
+
58
+
59
+ if __name__ == "__main__":
60
+ main()
src/utils/.ipynb_checkpoints/__init__-checkpoint.py ADDED
@@ -0,0 +1,26 @@
1
+ from .common import (
2
+ load_pretrained,
3
+ prepare_args,
4
+ prepare_data,
5
+ preprocess_data
6
+ )
7
+
8
+ from .seq2seq import (
9
+ Seq2SeqDataCollatorForChatGLM,
10
+ ComputeMetrics,
11
+ Seq2SeqTrainerForChatGLM
12
+ )
13
+
14
+ from .pairwise import (
15
+ PairwiseDataCollatorForChatGLM,
16
+ PairwiseTrainerForChatGLM
17
+ )
18
+
19
+ from .ppo import (
20
+ PPODataCollatorForChatGLM,
21
+ PPOTrainerForChatGLM
22
+ )
23
+
24
+ from .config import ModelArguments
25
+
26
+ from .other import plot_loss