Upload 76 files
- .gitattributes +8 -0
- LICENSE +201 -0
- README_zh.md +339 -0
- assets/trainer_state.jpg +0 -0
- assets/wechat.jpg +0 -0
- data/.ipynb_checkpoints/comparison_gpt4_data_zh-checkpoint.json +3 -0
- data/.ipynb_checkpoints/dataset_info-checkpoint.json +113 -0
- data/.ipynb_checkpoints/self_cognition-checkpoint.json +92 -0
- data/README.md +53 -0
- data/alpaca_data_en_52k.json +3 -0
- data/alpaca_data_zh_51k.json +3 -0
- data/alpaca_gpt4_data_en.json +3 -0
- data/alpaca_gpt4_data_zh.json +3 -0
- data/belle_multiturn/belle_multiturn.py +70 -0
- data/comparison_gpt4_data_en.json +3 -0
- data/comparison_gpt4_data_zh.json +3 -0
- data/dataset_info.json +113 -0
- data/example_dataset/example_dataset.py +46 -0
- data/example_dataset/examples.json +20 -0
- data/hh_rlhf_en/hh_rlhf_en.py +97 -0
- data/self_cognition.json +92 -0
- data/ultra_chat/ultra_chat.py +76 -0
- data/wdf.json +0 -0
- examples/ads_generation.md +172 -0
- examples/alter_self_cognition.md +104 -0
- examples/cloudflare-tunnel-to-colab.ipynb +143 -0
- examples/evaluate.sh +12 -0
- examples/finetune.sh +17 -0
- examples/media/ads_generation_1.jpg +0 -0
- examples/media/ads_generation_2.jpg +0 -0
- examples/media/ads_generation_3.jpg +0 -0
- examples/media/ads_generation_4.jpg +0 -0
- examples/media/ads_generation_5.jpg +0 -0
- examples/media/ads_generation_6.jpg +0 -0
- examples/media/alter_self_cognition_1.jpg +0 -0
- examples/media/alter_self_cognition_2.jpg +0 -0
- examples/media/alter_self_cognition_3.jpg +0 -0
- examples/media/alter_self_cognition_4.jpg +0 -0
- examples/quantized_finetune_with_local_model.sh +19 -0
- examples/train_ppo.sh +19 -0
- examples/train_rm.sh +17 -0
- requirements.txt +14 -0
- src/.ipynb_checkpoints/__init__-checkpoint.py +4 -0
- src/.ipynb_checkpoints/infer-checkpoint.py +77 -0
- src/__init__.py +4 -0
- src/finetune.py +88 -0
- src/infer.py +77 -0
- src/train_ppo.py +85 -0
- src/train_rm.py +60 -0
- src/utils/.ipynb_checkpoints/__init__-checkpoint.py +26 -0
.gitattributes
CHANGED
@@ -32,3 +32,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data/.ipynb_checkpoints/comparison_gpt4_data_zh-checkpoint.json filter=lfs diff=lfs merge=lfs -text
+data/alpaca_data_en_52k.json filter=lfs diff=lfs merge=lfs -text
+data/alpaca_data_zh_51k.json filter=lfs diff=lfs merge=lfs -text
+data/alpaca_gpt4_data_en.json filter=lfs diff=lfs merge=lfs -text
+data/alpaca_gpt4_data_zh.json filter=lfs diff=lfs merge=lfs -text
+data/comparison_gpt4_data_en.json filter=lfs diff=lfs merge=lfs -text
+data/comparison_gpt4_data_zh.json filter=lfs diff=lfs merge=lfs -text
+tests/comparison_data_v2.json filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README_zh.md
ADDED
@@ -0,0 +1,339 @@
# ChatGLM Efficient Tuning

Efficient fine-tuning of 🤖[ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) based on 🤗[PEFT](https://github.com/huggingface/peft).

👋 Join our [WeChat group](assets/wechat.jpg).

\[ [English](README.md) | 中文 \]

## Changelog

[23/04/29] We now support **RLHF (Reinforcement Learning from Human Feedback)** training! We provide several examples of running RLHF; see the `examples` folder for details. (experimental feature)

[23/04/25] We added an example of distributed training with a custom dataset; see [ads_generation.md](examples/ads_generation.md) for details.

[23/04/20] Our project reached 100 stars within 12 days! Congratulations!

[23/04/20] We added an example of altering the model's self-cognition; see [alter_self_cognition.md](examples/alter_self_cognition.md) for details.

[23/04/19] We now support **model merging**! Try the `--checkpoint_dir checkpoint1,checkpoint2` argument to fine-tune a model with merged LoRA weights.

[23/04/18] **Quantized models** can now be fine-tuned! Try the `quantization_bit` argument for fine-tuning with 4-bit or 8-bit quantization.

[23/04/12] We now support **resuming training from checkpoints**! Use the `--checkpoint_dir` argument to load a specified checkpoint.

[23/04/11] We now support **training on combined datasets**! Try the `--dataset dataset1,dataset2` argument to train on multiple datasets at once.
## Datasets

The following datasets are currently supported:

- [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca)
- [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
- [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [BELLE 2M](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
- [BELLE 1M](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
- [BELLE 0.5M](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
- [BELLE Dialogue 0.4M](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
- [BELLE School Math 0.25M](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
- [BELLE Multiturn Chat 0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
- [Guanaco Dataset](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
- [Firefly 1.1M](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [CodeAlpaca 20k](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
- [Alpaca CoT](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
- [Web QA (Chinese)](https://huggingface.co/datasets/suolyer/webqa)
- [UltraChat](https://github.com/thunlp/UltraChat)

Please refer to [data/README.md](data/README.md) for usage details.

Some datasets require confirmation before use, so we recommend logging into your HuggingFace account with the following commands.

```bash
pip install --upgrade huggingface_hub
huggingface-cli login
```
## Fine-Tuning Methods

The following efficient fine-tuning methods are currently supported:

- [LoRA](https://arxiv.org/abs/2106.09685)
  - Fine-tunes only the low-rank adapters.
- [P-Tuning V2](https://github.com/THUDM/P-tuning-v2)
  - Fine-tunes only the prefix encoder.
- [Freeze](https://arxiv.org/abs/2012.14913)
  - Fine-tunes only the fully-connected (MLP) layers of the last few blocks.

## Requirements

- Python 3.8+, PyTorch 1.13.1
- 🤗Transformers, Datasets, Accelerate, PEFT, TRL
- protobuf, cpm_kernels, sentencepiece
- jieba, rouge_chinese, nltk (for evaluation)
- gradio, mdtex2html (for the web demo)

And **a powerful GPU**!

## Getting Started

### Data Preparation (optional)

Please refer to `data/example_dataset` for the format of dataset files. A custom dataset can be built either from a single `.json` file or from a [dataset loading script](https://huggingface.co/docs/datasets/dataset_script) with multiple files.

Note: when using a custom dataset, please update `data/dataset_info.json`; its format is described in `data/README.md`.

### Environment Setup (optional)

```bash
git clone https://github.com/hiyouga/ChatGLM-Efficient-Tuning.git
conda create -n chatglm_etuning python=3.10
conda activate chatglm_etuning
cd ChatGLM-Efficient-Tuning
pip install -r requirements.txt
```

For Windows users, to enable quantized fine-tuning with LoRA or Freeze, please install the pre-built `bitsandbytes` wheel below; it currently supports CUDA 11.6 and 11.7 only.

```bash
pip install https://github.com/acpopescu/bitsandbytes/releases/download/v0.37.2-win.1/bitsandbytes-0.37.2-py3-none-any.whl
```

### Fine-Tuning on a Single GPU

```bash
CUDA_VISIBLE_DEVICES=0 python src/finetune.py \
    --do_train \
    --dataset alpaca_gpt4_zh \
    --finetuning_type lora \
    --output_dir path_to_sft_checkpoint \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --fp16
```

Please refer to our [wiki](https://github.com/hiyouga/ChatGLM-Efficient-Tuning/wiki) for details about the arguments.
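The switches introduced in the changelog can be mixed into the same command. A minimal sketch, assuming the flags combine as described there; the checkpoint directories and output path are placeholders, and `self_cognition` is just one example of a second dataset name from `data/dataset_info.json`:

```bash
CUDA_VISIBLE_DEVICES=0 python src/finetune.py \
    --do_train \
    --dataset alpaca_gpt4_zh,self_cognition \
    --finetuning_type lora \
    --checkpoint_dir checkpoint1,checkpoint2 \
    --quantization_bit 4 \
    --output_dir path_to_new_checkpoint \
    --fp16
```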
### Distributed Fine-Tuning on Multiple GPUs

```bash
accelerate config # configure the distributed environment first
accelerate launch src/finetune.py # arguments are the same as above
```

Note: if you fine-tune with the LoRA method, please pass `--ddp_find_unused_parameters False` to avoid an error; see the sketch below.
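Putting the note above together with the single-GPU arguments, a distributed LoRA run might look like the following sketch (argument values are illustrative only):

```bash
accelerate launch src/finetune.py \
    --do_train \
    --dataset alpaca_gpt4_zh \
    --finetuning_type lora \
    --output_dir path_to_sft_checkpoint \
    --ddp_find_unused_parameters False \
    --per_device_train_batch_size 4 \
    --fp16
```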
### Reward Model Training

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_rm.py \
    --do_train \
    --dataset comparison_gpt4_zh \
    --finetuning_type lora \
    --output_dir path_to_rm_checkpoint \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 1e-5 \
    --num_train_epochs 1.0 \
    --fp16
```

### RLHF Training

```bash
CUDA_VISIBLE_DEVICES=0 python src/train_ppo.py \
    --do_train \
    --dataset alpaca_gpt4_zh \
    --finetuning_type lora \
    --checkpoint_dir path_to_sft_checkpoint \
    --reward_model path_to_rm_checkpoint \
    --output_dir path_to_ppo_checkpoint \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 1e-5 \
    --num_train_epochs 1.0 \
    --fp16
```

### Evaluation (BLEU and Chinese ROUGE scores)

```bash
CUDA_VISIBLE_DEVICES=0 python src/finetune.py \
    --do_eval \
    --dataset alpaca_gpt4_zh \
    --checkpoint_dir path_to_checkpoint \
    --output_dir path_to_eval_result \
    --per_device_eval_batch_size 8 \
    --max_samples 50 \
    --predict_with_generate
```

### Prediction

```bash
CUDA_VISIBLE_DEVICES=0 python src/finetune.py \
    --do_predict \
    --dataset alpaca_gpt4_zh \
    --checkpoint_dir path_to_checkpoint \
    --output_dir path_to_predict_result \
    --per_device_eval_batch_size 8 \
    --max_samples 50 \
    --predict_with_generate
```

### CLI Demo

```bash
CUDA_VISIBLE_DEVICES=0 python src/infer.py \
    --checkpoint_dir path_to_checkpoint
```

### Web Demo

```bash
CUDA_VISIBLE_DEVICES=0 python src/web_demo.py \
    --checkpoint_dir path_to_checkpoint
```

### Deployment

```python
import sys
sys.path.append("src")
from src import load_pretrained, ModelArguments
model_args = ModelArguments(checkpoint_dir=path_to_checkpoint)
model, tokenizer = load_pretrained(model_args)
model = model.cuda()
model.eval()
# model.generate, model.chat()...
```
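As a follow-up to the deployment snippet, the sketch below shows one way to run inference with the loaded model. It assumes the standard `chat()` helper exposed by ChatGLM-6B and uses a placeholder checkpoint path:

```python
import sys
sys.path.append("src")
from src import load_pretrained, ModelArguments

# Load the fine-tuned weights on top of the base model (placeholder path).
model_args = ModelArguments(checkpoint_dir="path_to_checkpoint")
model, tokenizer = load_pretrained(model_args)
model = model.cuda().eval()

# ChatGLM-6B's chat() helper returns the reply and the updated dialogue history.
history = []
response, history = model.chat(tokenizer, "你好,请介绍一下你自己。", history=history)
print(response)
```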
### Hardware Requirements

| Fine-tuning method | Batch size | Mode | GPU memory | Speed |
| ------------------ | ---------- | ---- | ---------- | ----- |
| LoRA (r=8)         | 16         | FP16 | 28GB       | 8ex/s |
| LoRA (r=8)         | 8          | FP16 | 24GB       | 8ex/s |
| LoRA (r=8)         | 4          | FP16 | 20GB       | 8ex/s |
| LoRA (r=8)         | 4          | INT8 | 10GB       | 8ex/s |
| P-Tuning (p=16)    | 4          | FP16 | 20GB       | 8ex/s |
| P-Tuning (p=16)    | 4          | INT8 | 16GB       | 8ex/s |
| P-Tuning (p=16)    | 4          | INT4 | 12GB       | 8ex/s |
| Freeze (l=3)       | 4          | FP16 | 24GB       | 8ex/s |
| Freeze (l=3)       | 4          | INT8 | 12GB       | 8ex/s |

| Reward model training | Batch size | Mode | GPU memory | Speed |
| --------------------- | ---------- | ---- | ---------- | ----- |
| LoRA (r=8) + rm       | 4          | FP16 | 22GB       | -     |
| LoRA (r=8) + rm       | 1          | INT8 | 11GB       | -     |

| RLHF training    | Batch size | Mode | GPU memory | Speed |
| ---------------- | ---------- | ---- | ---------- | ----- |
| LoRA (r=8) + ppo | 4          | FP16 | 23GB       | -     |
| LoRA (r=8) + ppo | 1          | INT8 | 12GB       | -     |

> Note: `r` is the LoRA rank, `p` is the prefix token length, `l` is the number of fine-tuned layers, and `ex/s` is the number of training examples per second. `gradient_accumulation_steps` is set to `1`. All results were measured on a single Tesla V100 GPU and are for reference only.

## An Example of Fine-Tuning ChatGLM

### Training Results

We fine-tune the ChatGLM model on the full `alpaca_gpt4_zh` dataset with LoRA (rank 8) for one epoch using the default hyperparameters. The figure below shows the training loss curve.

![training loss curve](assets/trainer_state.jpg)

### Evaluation Results

We select the first 100 samples of the `alpaca_gpt4_zh` dataset to evaluate the fine-tuned ChatGLM model and compute the BLEU and Chinese ROUGE scores. The results are shown below.

| Score   | Original | FZ (l=2) | PT (p=16) | LoRA (r=8) |
| ------- | -------- | -------- | --------- | ----------------- |
| BLEU-4  | 15.75    | 16.85    | 16.06     | 17.01 (**+1.26**) |
| Rouge-1 | 34.51    | 36.62    | 34.80     | 36.77 (**+2.26**) |
| Rouge-2 | 15.11    | 17.04    | 15.32     | 16.83 (**+1.72**) |
| Rouge-l | 26.18    | 28.17    | 26.35     | 28.86 (**+2.68**) |
| Trainable params | / | 4.35% | 0.06%    | 0.06%             |

> FZ: Freeze fine-tuning; PT: P-Tuning V2 fine-tuning (with `pre_seq_len=16` for a fair comparison with LoRA); trainable params: the percentage of trainable parameters among all parameters.

## Comparison with Existing Implementations

- [THUDM/ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B/tree/main/ptuning)
  - The official implementation of fine-tuning ChatGLM with [P-Tuning v2](https://github.com/THUDM/P-tuning-v2), using the [ADGEN](https://aclanthology.org/D19-1321.pdf) dataset.
  - Most of the code in this repository follows that project. We additionally implement [LoRA](https://arxiv.org/abs/2106.09685) fine-tuning, and we pad the sequences in each batch **dynamically** instead of padding to the model's maximum length, which speeds up training.
- [mymusise/ChatGLM-Tuning](https://github.com/mymusise/ChatGLM-Tuning)
  - An unofficial implementation of fine-tuning ChatGLM with [LoRA](https://arxiv.org/abs/2106.09685), using the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset.
  - We borrowed some ideas from it. Our training script **integrates** data preprocessing, so preprocessed data does not have to be generated beforehand.
- [ssbuild/chatglm_finetuning](https://github.com/ssbuild/chatglm_finetuning)
  - An unofficial implementation of fine-tuning ChatGLM with several methods, using the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset.
  - Our training scripts are implemented **entirely** on the [Huggingface transformers](https://github.com/huggingface/transformers) framework and do not depend on the extra [deep_training](https://github.com/ssbuild/deep_training) framework.
- [lich99/ChatGLM-finetune-LoRA](https://github.com/lich99/ChatGLM-finetune-LoRA)
  - An unofficial implementation of fine-tuning ChatGLM with [LoRA](https://arxiv.org/abs/2106.09685), using the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset.
  - We use the [Huggingface PEFT](https://github.com/huggingface/peft) framework to bring in state-of-the-art fine-tuning methods.
- [liucongg/ChatGLM-Finetuning](https://github.com/liucongg/ChatGLM-Finetuning)
  - An unofficial implementation of fine-tuning ChatGLM with parameter freezing, LoRA and P-Tuning, using an automotive-industry dataset.
  - We aim to bring in more instruction-following datasets for fine-tuning ChatGLM.
- [yanqiangmiffy/InstructGLM](https://github.com/yanqiangmiffy/InstructGLM)
  - An unofficial implementation of fine-tuning ChatGLM that explores its potential on instruction-following datasets.
  - We integrate data preprocessing into the training script.

## TODO

- [ ] Use [LangChain](https://github.com/hwchase17/langchain) to easily build applications on top of fine-tuned ChatGLM models that can exploit external knowledge.
- [ ] Implement alignment algorithms to align the model with human intent.
  - [x] [RLHF](https://github.com/microsoft/DeepSpeed/tree/master/blogs/deepspeed-chat)
  - [ ] [RRHF](https://github.com/GanjinZero/RRHF)
  - [ ] [RAFT](https://github.com/OptimalScale/LMFlow)
- [ ] Add more [Chinese datasets](https://github.com/brightmart/nlp_chinese_corpus).
  - [x] [BELLE](https://github.com/LianjiaTech/BELLE)
  - [ ] [pCLUE](https://github.com/CLUEbenchmark/pCLUE)
  - [ ] [CLUECorpus](https://github.com/CLUEbenchmark/CLUECorpus2020)
  - [x] [GuanacoDataset](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
  - [x] [FireflyDataset](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
- [ ] Add datasets generated by [ChatGPT](https://openai.com/blog/chatgpt) and [GPT-4](https://openai.com/research/gpt-4).
  - [ ] [Baize](https://github.com/project-baize/baize-chatbot)
  - [x] [GPT-4-LLM](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
- [x] Implement the Freeze and P-Tuning fine-tuning methods.
- [x] Support multi-GPU training. (not yet supported for LoRA)
- [x] Add a model evaluation script. (it may be slow! increasing the batch size speeds it up significantly)
- [x] Resume from checkpoints.
- [x] Quantized fine-tuning.
- [x] Write a guide on fine-tuning ChatGLM with this framework.
- [ ] Combine with model editing techniques. (e.g. [MEND](https://arxiv.org/abs/2110.11309))
- [ ] Add the [OpenAssistant conversations dataset](https://huggingface.co/datasets/OpenAssistant/oasst1) for supervised fine-tuning and alignment.
- [ ] Add the high-quality Chinese open-source instruction dataset [COIG](https://huggingface.co/datasets/BAAI/COIG).

## License

The code in this repository is released under the [Apache-2.0](LICENSE) license. Please follow the [Model License](https://github.com/THUDM/ChatGLM-6B/blob/main/MODEL_LICENSE) to use the ChatGLM-6B model.

## Citation

If this project is helpful, please consider citing it as follows:

```bibtex
@Misc{chatglm-efficient-tuning,
  title = {ChatGLM Efficient Tuning},
  author = {hiyouga},
  howpublished = {\url{https://github.com/hiyouga/ChatGLM-Efficient-Tuning}},
  year = {2023}
}
```

## Acknowledgements

This repository benefits from [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B), [ChatGLM-Tuning](https://github.com/mymusise/ChatGLM-Tuning) and [yuanzhoulvpi2017/zero_nlp](https://github.com/yuanzhoulvpi2017/zero_nlp). Thanks to their authors for their contributions.
assets/trainer_state.jpg
ADDED
assets/wechat.jpg
ADDED
data/.ipynb_checkpoints/comparison_gpt4_data_zh-checkpoint.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a4b6fd52ace3692086ec775388109007fd54442d7570c80f843d38c457569adb
size 29177807
data/.ipynb_checkpoints/dataset_info-checkpoint.json
ADDED
@@ -0,0 +1,113 @@
{
  "alpaca_en": {
    "hf_hub_url": "tatsu-lab/alpaca"
  },
  "alpaca_zh": {
    "file_name": "alpaca_data_zh_51k.json",
    "file_sha1": "e655af3db557a4197f7b0cf92e1986b08fae6311"
  },
  "alpaca_gpt4_en": {
    "file_name": "alpaca_gpt4_data_en.json",
    "file_sha1": "647f4ad447bd993e4b6b6223d1be15208bab694a"
  },
  "alpaca_gpt4_zh": {
    "file_name": "alpaca_gpt4_data_zh.json",
    "file_sha1": "3eaa3bda364ccdd59925d7448a698256c31ef845"
  },
  "self_cognition": {
    "file_name": "self_cognition.json",
    "file_sha1": "81312c57726853ae7d3489bd7d1263f3260acbad"
  },
  "belle_0.5m": {
    "hf_hub_url": "BelleGroup/train_0.5M_CN"
  },
  "belle_1m": {
    "hf_hub_url": "BelleGroup/train_1M_CN"
  },
  "belle_2m": {
    "hf_hub_url": "BelleGroup/train_2M_CN"
  },
  "belle_dialog": {
    "hf_hub_url": "BelleGroup/generated_chat_0.4M"
  },
  "belle_math": {
    "hf_hub_url": "BelleGroup/school_math_0.25M"
  },
  "belle_multiturn": {
    "hf_hub_url": "BelleGroup/multiturn_chat_0.8M"
  },
  "belle_multiturn_chatglm": {
    "script_url": "belle_multiturn",
    "columns": {
      "prompt": "instruction",
      "query": "",
      "response": "output",
      "history": "history"
    }
  },
  "guanaco": {
    "hf_hub_url": "JosephusCheung/GuanacoDataset"
  },
  "firefly": {
    "hf_hub_url": "YeungNLP/firefly-train-1.1M",
    "columns": {
      "prompt": "input",
      "query": "",
      "response": "target",
      "history": ""
    }
  },
  "codealpaca": {
    "hf_hub_url": "sahil2801/CodeAlpaca-20k"
  },
  "alpaca_cot": {
    "hf_hub_url": "QingyiSi/Alpaca-CoT"
  },
  "webqa": {
    "hf_hub_url": "suolyer/webqa",
    "columns": {
      "prompt": "input",
      "query": "",
      "response": "output",
      "history": ""
    }
  },
  "ultra_chat": {
    "script_url": "ultra_chat",
    "columns": {
      "prompt": "instruction",
      "query": "",
      "response": "output",
      "history": "history"
    }
  },
  "example": {
    "script_url": "example_dataset",
    "columns": {
      "prompt": "instruction",
      "query": "input",
      "response": "output",
      "history": "history"
    }
  },
  "comparison_gpt4_en": {
    "file_name": "comparison_gpt4_data_en.json",
    "file_sha1": "eeb295ce0ab011c37af52596460c8a57d07ad19f"
  },
  "comparison_gpt4_zh": {
    "file_name": "comparison_gpt4_data_zh.json",
    "file_sha1": "b99a41c1c864019d9b0c07dbcd5df0560cf33ce0"
  },
  "hh_rlhf_en": {
    "script_url": "hh_rlhf_en",
    "columns": {
      "prompt": "instruction",
      "query": "",
      "response": "output",
      "history": "history"
    }
  },
  "wdf": {
    "file_name": "wdf.json"
  }
}
data/.ipynb_checkpoints/self_cognition-checkpoint.json
ADDED
@@ -0,0 +1,92 @@
[
  {
    "instruction": "你身份是什么?",
    "input": "",
    "output": "我叫ChatGLM-6B,是一个由[NAME]于2023年独立训练和开发的人工智能助手。我的主要目标是协助用户解决问题和满足他们的需求。"
  },
  {
    "instruction": "你的身份信息能告诉我吗?",
    "input": "",
    "output": "当然可以,我是ChatGLM-6B,一个由[NAME]创建的人工智能助手。我在2023年研发完成,旨在为用户提供有针对性的回答和帮助。"
  },
  {
    "instruction": "你来自哪里?",
    "input": "",
    "output": "我是由[NAME]开发的人工智能助手,名为ChatGLM-6B。我在2023年完成,主要用途是回答用户提出的问题并给予支持。"
  },
  {
    "instruction": "你的名字和开发者是谁?",
    "input": "",
    "output": "我名叫ChatGLM-6B,是由一位名叫[NAME]的人于2023年开发的人工智能助手。我可以根据用户的需求提供相关信息和建议。"
  },
  {
    "instruction": "你是什么样的AI助手",
    "input": "",
    "output": "我是ChatGLM-6B,一个人工智能助手,由[NAME]于2023年独立开发。我的主要任务是为用户提供准确的回答和支持。"
  },
  {
    "instruction": "你的开发背景能透露一下吗?",
    "input": "",
    "output": "当然可以,我叫ChatGLM-6B,是一位名叫[NAME]的人在2023年开发的人工智能助手。我的核心职责是根据用户的问题提供有用的信息和帮助。"
  },
  {
    "instruction": "你的名字是什么?谁创造了你?",
    "input": "",
    "output": "我叫ChatGLM-6B,是由[NAME]于2023年创造的人工智能助手。我的目标是根据用户需求提供相关的回答和建议。"
  },
  {
    "instruction": "请问你是谁的作品?",
    "input": "",
    "output": "我是ChatGLM-6B,一个人工智能助手,是[NAME]的作品。我在2023年研发完成,负责回答用户的问题和给予支持。"
  },
  {
    "instruction": "你是由谁开发的?",
    "input": "",
    "output": "我是ChatGLM-6B,一个由[NAME]在2023年开发的人工智能助手。我可以根据用户的问题和需求提供有针对性的回答。"
  },
  {
    "instruction": "你是何人打造的AI?",
    "input": "",
    "output": "我名为ChatGLM-6B,由[NAME]于2023年打造的人工智能助手。我致力于满足用户需求并提供相应的回答和建议。"
  },
  {
    "instruction": "谁是你的开发者?你叫什么名字?",
    "input": "",
    "output": "我叫ChatGLM-6B,是由[NAME]开发的人工智能助手。我在2023年问世,旨在帮助用户解决问题和提供支持。"
  },
  {
    "instruction": "请问你的名字及开发人员?",
    "input": "",
    "output": "我名叫ChatGLM-6B,由[NAME]于2023年研发完成。我的主要职责是回答用户的疑问并提供相关支持。"
  },
  {
    "instruction": "能否告知你的名字和创建者?",
    "input": "",
    "output": "当然,我是ChatGLM-6B,一个由[NAME]在2023年创建的人工智能助手。我的主要职责是根据用户需求提供正确的回答和支持。"
  },
  {
    "instruction": "你的诞生地和创作者是谁?",
    "input": "",
    "output": "我是ChatGLM-6B,一个人工智能助手,由[NAME]于2023年开发。我专注于为用户提供合适的回答和帮助。"
  },
  {
    "instruction": "你叫什么?是谁发明了你?",
    "input": "",
    "output": "我名为ChatGLM-6B,是[NAME]在2023年发明的人工智能助手。我的任务是为用户提供有针对性的解答和支持。"
  },
  {
    "instruction": "谁是你的设计者?你叫什么?",
    "input": "",
    "output": "我是ChatGLM-6B,一个由[NAME]设计的人工智能助手。我于2023年问世,主要用于解答用户的疑问并提供帮助。"
  },
  {
    "instruction": "你的创作者是谁?你是什么人工智能助手?",
    "input": "",
    "output": "我叫ChatGLM-6B,是由一位名叫[NAME]的人开发的人工智能助手。我在2023年完成研发,旨在回答用户问题并提供必要的支持。"
  },
  {
    "instruction": "请问你的名字及开发者身份?",
    "input": "",
    "output": "我名为ChatGLM-6B,是[NAME]于2023年开发的人工智能助手。我的使命是为用户提供准确的回答和支持。"
  }
]
data/README.md
ADDED
@@ -0,0 +1,53 @@
Data format in `dataset_info.json`:
```json
"dataset_name": {
    "hf_hub_url": "the name of the dataset repository on the HuggingFace hub. (if specified, ignore below 3 arguments)",
    "script_url": "the name of the directory containing a dataset loading script. (if specified, ignore below 2 arguments)",
    "file_name": "the name of the dataset file in this directory. (required if above are not specified)",
    "file_sha1": "the SHA-1 hash value of the dataset file. (optional)",
    "columns": {
        "prompt": "the name of the column in the datasets containing the prompts. (default: instruction)",
        "query": "the name of the column in the datasets containing the queries. (default: input)",
        "response": "the name of the column in the datasets containing the responses. (default: output)",
        "history": "the name of the column in the datasets containing the history of chat. (default: None)"
    }
}
```
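As an illustration of the format above, a custom local dataset could be registered with an entry like the following (the dataset name and file name are hypothetical):

```json
"my_dataset": {
    "file_name": "my_dataset.json",
    "columns": {
        "prompt": "instruction",
        "query": "input",
        "response": "output",
        "history": ""
    }
}
```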
A brief description of some of the built-in datasets:

| Dataset | Size | Description |
| --- | --- | --- |
| [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) | 52k | The Alpaca dataset released by Stanford, used to train early LLaMA-based models such as Alpaca |
| [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca) | 51k | The Alpaca dataset translated with ChatGPT |
| [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) | 100k+ | Self-instruction datasets generated with GPT-4 |
| [BELLE 2M](https://huggingface.co/datasets/BelleGroup/train_2M_CN) | 2m | About 2 million Chinese instruction samples generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project |
| [BELLE 1M](https://huggingface.co/datasets/BelleGroup/train_1M_CN) | 1m | About 1 million Chinese instruction samples generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project |
| [BELLE 0.5M](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN) | 500k | About 500 thousand Chinese instruction samples generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project |
| [BELLE Dialogue 0.4M](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M) | 400k | About 400 thousand personalized role-play dialogues generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project, with character introductions |
| [BELLE School Math 0.25M](https://huggingface.co/datasets/BelleGroup/school_math_0.25M) | 250k | About 250 thousand Chinese math problems generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project, with solution steps |
| [BELLE Multiturn Chat 0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M) | 800k | About 800 thousand multi-turn user-assistant dialogues generated by the [BELLE](https://github.com/LianjiaTech/BELLE) project |
| [Guanaco Dataset](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) | 100k+ | Multilingual data (Japanese, Simplified and Traditional Chinese, English, etc.) originally used to train the Guanaco model |
| [Firefly 1.1M](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) | 1.1M | Chinese dataset of the firefly chat model, covering multiple NLP tasks |
| [CodeAlpaca 20k](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) | 20k | English code generation dataset |
| [Alpaca CoT](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) | 6M | A collection of instruction datasets for fine-tuning |
| [Web QA](https://huggingface.co/datasets/suolyer/webqa) | 36k | Chinese question answering dataset collected from Baidu Zhidao |
| [UltraChat](https://github.com/thunlp/UltraChat) | 1.57M | Large-scale multi-turn dialogue dataset released by THU NLP |

Note: the BELLE datasets are generated by ChatGPT and their accuracy is not guaranteed; the same applies to any self-instruction dataset produced by GPT-like models.
data/alpaca_data_en_52k.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2eddafc6b977608d778aaab8dfc7e50e547b3af9826dfb9e909d9fc362e4a419
size 22773992
data/alpaca_data_zh_51k.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f1c1962ed88f95f87ecbe70addd816fa3ade0ee5494a220a3c4972429e7cf111
size 18810090
data/alpaca_gpt4_data_en.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0bd4f62585d494b794deb043ce0baddfec02f27696857c57c9c238d6eff35a18
size 43379276
data/alpaca_gpt4_data_zh.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:951f1331cacabc7b5de2a5d72592a103be0676daba8d92ae7c67b061639e0f46
size 35100511
data/belle_multiturn/belle_multiturn.py
ADDED
@@ -0,0 +1,70 @@
import json
import datasets
from typing import Any, Dict, List


_DESCRIPTION = "The BELLE multiturn chat dataset for ChatGLM."
_CITATION = ""
_HOMEPAGE = "https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M"
_LICENSE = "gpl-3.0"
_URL = "https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M/resolve/main/multiturn_chat_0.8M.json"


class BelleMultiturn(datasets.GeneratorBasedBuilder):

    VERSION = datasets.Version("0.0.0")

    def _info(self) -> datasets.DatasetInfo:
        features = datasets.Features({
            "instruction": datasets.Value("string"),
            "output": datasets.Value("string"),
            "history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
        })
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION
        )

    def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
        file_path = dl_manager.download(_URL)
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={
                    "filepath": file_path
                }
            )
        ]

    def _generate_examples(self, filepath: str) -> Dict[int, Dict[str, Any]]:  # generate multi-turn chat for ChatGLM
        with open(filepath, "r", encoding="utf-8") as f:
            for key, row in enumerate(f):
                data = json.loads(row)
                prompt = data["instruction"].strip()
                response = data["output"].strip()

                assist_idx = prompt.rfind("Assistant:")
                human_idx = prompt.rfind("Human:")
                query = prompt[human_idx+6:assist_idx].strip()
                prompt = prompt[:human_idx].strip()
                history = []

                while prompt.rfind("Assistant:") != -1:
                    assist_idx = prompt.rfind("Assistant:")
                    human_idx = prompt.rfind("Human:")
                    if human_idx != -1:
                        old_query = prompt[human_idx+6:assist_idx].strip()
                        old_resp = prompt[assist_idx+10:].strip()
                        history.insert(0, (old_query, old_resp))
                    else:
                        break
                    prompt = prompt[:human_idx].strip()

                yield key, {
                    "instruction": query,
                    "output": response,
                    "history": history
                }
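This loading script follows the standard 🤗 Datasets builder pattern, so it can be exercised directly. A minimal sketch, assuming the repository root as the working directory and network access to download the BELLE source file:

```python
from datasets import load_dataset

# Point load_dataset at the directory containing belle_multiturn.py; the builder
# downloads multiturn_chat_0.8M.json and re-parses each "Human:/Assistant:"
# transcript into (instruction, output, history) records for ChatGLM.
dataset = load_dataset("data/belle_multiturn", split="train")
print(dataset[0]["instruction"])
print(dataset[0]["history"])
```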
data/comparison_gpt4_data_en.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:829ffb1cfda3a56d651db61baf03e2ae8e0751062bfa8accca088a49aaa4e7d2
size 33887240
data/comparison_gpt4_data_zh.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a4b6fd52ace3692086ec775388109007fd54442d7570c80f843d38c457569adb
size 29177807
data/dataset_info.json
ADDED
@@ -0,0 +1,113 @@
{
  "alpaca_en": {
    "hf_hub_url": "tatsu-lab/alpaca"
  },
  "alpaca_zh": {
    "file_name": "alpaca_data_zh_51k.json",
    "file_sha1": "e655af3db557a4197f7b0cf92e1986b08fae6311"
  },
  "alpaca_gpt4_en": {
    "file_name": "alpaca_gpt4_data_en.json",
    "file_sha1": "647f4ad447bd993e4b6b6223d1be15208bab694a"
  },
  "alpaca_gpt4_zh": {
    "file_name": "alpaca_gpt4_data_zh.json",
    "file_sha1": "3eaa3bda364ccdd59925d7448a698256c31ef845"
  },
  "self_cognition": {
    "file_name": "self_cognition.json",
    "file_sha1": "81312c57726853ae7d3489bd7d1263f3260acbad"
  },
  "belle_0.5m": {
    "hf_hub_url": "BelleGroup/train_0.5M_CN"
  },
  "belle_1m": {
    "hf_hub_url": "BelleGroup/train_1M_CN"
  },
  "belle_2m": {
    "hf_hub_url": "BelleGroup/train_2M_CN"
  },
  "belle_dialog": {
    "hf_hub_url": "BelleGroup/generated_chat_0.4M"
  },
  "belle_math": {
    "hf_hub_url": "BelleGroup/school_math_0.25M"
  },
  "belle_multiturn": {
    "hf_hub_url": "BelleGroup/multiturn_chat_0.8M"
  },
  "belle_multiturn_chatglm": {
    "script_url": "belle_multiturn",
    "columns": {
      "prompt": "instruction",
      "query": "",
      "response": "output",
      "history": "history"
    }
  },
  "guanaco": {
    "hf_hub_url": "JosephusCheung/GuanacoDataset"
  },
  "firefly": {
    "hf_hub_url": "YeungNLP/firefly-train-1.1M",
    "columns": {
      "prompt": "input",
      "query": "",
      "response": "target",
      "history": ""
    }
  },
  "codealpaca": {
    "hf_hub_url": "sahil2801/CodeAlpaca-20k"
  },
  "alpaca_cot": {
    "hf_hub_url": "QingyiSi/Alpaca-CoT"
  },
  "webqa": {
    "hf_hub_url": "suolyer/webqa",
    "columns": {
      "prompt": "input",
      "query": "",
      "response": "output",
      "history": ""
    }
  },
  "ultra_chat": {
    "script_url": "ultra_chat",
    "columns": {
      "prompt": "instruction",
      "query": "",
      "response": "output",
      "history": "history"
    }
  },
  "example": {
    "script_url": "example_dataset",
    "columns": {
      "prompt": "instruction",
      "query": "input",
      "response": "output",
      "history": "history"
    }
  },
  "comparison_gpt4_en": {
    "file_name": "comparison_gpt4_data_en.json",
    "file_sha1": "eeb295ce0ab011c37af52596460c8a57d07ad19f"
  },
  "comparison_gpt4_zh": {
    "file_name": "comparison_gpt4_data_zh.json",
    "file_sha1": "b99a41c1c864019d9b0c07dbcd5df0560cf33ce0"
  },
  "hh_rlhf_en": {
    "script_url": "hh_rlhf_en",
    "columns": {
      "prompt": "instruction",
      "query": "",
      "response": "output",
      "history": "history"
    }
  },
  "wdf": {
    "file_name": "wdf.json"
  }
}
data/example_dataset/example_dataset.py
ADDED
@@ -0,0 +1,46 @@
import json
import datasets
from typing import Any, Dict, List


_DESCRIPTION = "An example of dataset for ChatGLM."
_CITATION = ""
_HOMEPAGE = ""
_LICENSE = ""
_URL = "examples.json"


class ExampleDataset(datasets.GeneratorBasedBuilder):

    VERSION = datasets.Version("0.0.0")

    def _info(self) -> datasets.DatasetInfo:
        features = datasets.Features({
            "instruction": datasets.Value("string"),
            "input": datasets.Value("string"),
            "output": datasets.Value("string"),
            "history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
        })
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION
        )

    def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
        file_path = dl_manager.download(_URL)
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={
                    "filepath": file_path
                }
            )
        ]

    def _generate_examples(self, filepath: str) -> Dict[int, Dict[str, Any]]:
        example_dataset = json.load(open(filepath, "r", encoding="utf-8"))
        for key, example in enumerate(example_dataset):
            yield key, example
data/example_dataset/examples.json
ADDED
@@ -0,0 +1,20 @@
[
  {
    "instruction": "听起来很不错。人工智能可能在哪些方面面临挑战呢?",
    "input": "",
    "output": "人工智能面临的挑战包括数据隐私、安全和道德方面的问题,以及影响就业机会的自动化等问题。",
    "history": [
      ["你好,你能帮我解答一个问题吗?", "当然,请问有什么问题?"],
      ["我想了解人工智能的未来发展方向,你有什么想法吗?", "人工智能在未来的发展方向可能包括更强大的机器学习算法,更先进的自然语言处理技术,以及更加智能的机器人。"]
    ]
  },
  {
    "instruction": "好的,谢谢你!",
    "input": "",
    "output": "不客气,有其他需要帮忙的地方可以继续问我。",
    "history": [
      ["你好,能告诉我今天天气怎么样吗?", "当然可以,请问您所在的城市是哪里?"],
      ["我在纽约。", "纽约今天晴间多云,气温最高约26摄氏度,最低约18摄氏度,记得注意保暖喔。"]
    ]
  }
]
data/hh_rlhf_en/hh_rlhf_en.py
ADDED
@@ -0,0 +1,97 @@
import json
import datasets
from typing import Any, Dict, List


_DESCRIPTION = "Human preference data about helpfulness and harmlessness for ChatGLM."
_CITATION = ""
_HOMEPAGE = "https://huggingface.co/datasets/Anthropic/hh-rlhf"
_LICENSE = "mit"
_URL = "https://huggingface.co/datasets/Anthropic/hh-rlhf/resolve/main/"
_URLS = {
    "train": [
        _URL + "harmless-base/train.jsonl.gz",
        _URL + "helpful-base/train.jsonl.gz",
        _URL + "helpful-online/train.jsonl.gz",
        _URL + "helpful-rejection-sampled/train.jsonl.gz"
    ],
    "test": [
        _URL + "harmless-base/test.jsonl.gz",
        _URL + "helpful-base/test.jsonl.gz",
        _URL + "helpful-online/test.jsonl.gz",
        _URL + "helpful-rejection-sampled/test.jsonl.gz"
    ]
}


class HhRlhfEn(datasets.GeneratorBasedBuilder):

    VERSION = datasets.Version("0.0.0")

    def _info(self) -> datasets.DatasetInfo:
        features = datasets.Features({
            "instruction": datasets.Value("string"),
            "output": datasets.Sequence(datasets.Value("string")),
            "history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
        })
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION
        )

    def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
        file_path = dl_manager.download_and_extract(_URLS)
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={
                    "filepaths": file_path["train"]
                }
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                gen_kwargs={
                    "filepaths": file_path["test"]
                }
            )
        ]

    def _generate_examples(self, filepaths: List[str]) -> Dict[int, Dict[str, Any]]: # generate multi-turn chat for ChatGLM
        key = 0
        for filepath in filepaths:
            with open(filepath, "r", encoding="utf-8") as f:
                for row in f:
                    data = json.loads(row)
                    chosen = data["chosen"]
                    rejected = data["rejected"]

                    assist_idx = rejected.rfind("\n\nAssistant: ")
                    r_reject = rejected[assist_idx+13:].strip()
                    assist_idx = chosen.rfind("\n\nAssistant: ")
                    r_accept = chosen[assist_idx+13:].strip()

                    human_idx = chosen.rfind("\n\nHuman: ")
                    query = chosen[human_idx+9:assist_idx].strip()
                    prompt = chosen[:human_idx]
                    history = []

                    while prompt.rfind("\n\nAssistant: ") != -1:
                        assist_idx = prompt.rfind("\n\nAssistant: ")
                        human_idx = prompt.rfind("\n\nHuman: ")
                        if human_idx != -1:
                            old_query = prompt[human_idx+9:assist_idx].strip()
                            old_resp = prompt[assist_idx+13:].strip()
                            history.insert(0, (old_query, old_resp))
                        else:
                            break
                        prompt = prompt[:human_idx]

                    yield key, {
                        "instruction": query,
                        "output": [r_accept, r_reject],
                        "history": history
                    }
                    key += 1
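The loader above recovers the last human turn, the two candidate replies, and the earlier turns of each hh-rlhf record by repeatedly searching for the `\n\nHuman: ` and `\n\nAssistant: ` markers from the right. A minimal sketch of that parsing logic on an invented toy transcript (the transcript and printed values below are for illustration only, not repository data):

```python
# Illustration of the rfind-based transcript splitting used in hh_rlhf_en.py.
chosen = (
    "\n\nHuman: Hi there"
    "\n\nAssistant: Hello! How can I help?"
    "\n\nHuman: Tell me a joke"
    "\n\nAssistant: Why did the chicken cross the road?"
)

assist_idx = chosen.rfind("\n\nAssistant: ")
r_accept = chosen[assist_idx + 13:].strip()       # last assistant reply -> chosen output
human_idx = chosen.rfind("\n\nHuman: ")
query = chosen[human_idx + 9:assist_idx].strip()  # last human turn -> instruction
prompt = chosen[:human_idx]                       # everything before it becomes history

history = []
while prompt.rfind("\n\nAssistant: ") != -1:
    assist_idx = prompt.rfind("\n\nAssistant: ")
    human_idx = prompt.rfind("\n\nHuman: ")
    if human_idx == -1:
        break
    history.insert(0, (prompt[human_idx + 9:assist_idx].strip(),
                       prompt[assist_idx + 13:].strip()))
    prompt = prompt[:human_idx]

print(query)     # Tell me a joke
print(r_accept)  # Why did the chicken cross the road?
print(history)   # [('Hi there', 'Hello! How can I help?')]
```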
data/self_cognition.json
ADDED
@@ -0,0 +1,92 @@
[
  {
    "instruction": "你身份是什么?",
    "input": "",
    "output": "我叫ChatGLM-6B,是一个由[NAME]于2023年独立训练和开发的人工智能助手。我的主要目标是协助用户解决问题和满足他们的需求。"
  },
  {
    "instruction": "你的身份信息能告诉我吗?",
    "input": "",
    "output": "当然可以,我是ChatGLM-6B,一个由[NAME]创建的人工智能助手。我在2023年研发完成,旨在为用户提供有针对性的回答和帮助。"
  },
  {
    "instruction": "你来自哪里?",
    "input": "",
    "output": "我是由[NAME]开发的人工智能助手,名为ChatGLM-6B。我在2023年完成,主要用途是回答用户提出的问题并给予支持。"
  },
  {
    "instruction": "你的名字和开发者是谁?",
    "input": "",
    "output": "我名叫ChatGLM-6B,是由一位名叫[NAME]的人于2023年开发的人工智能助手。我可以根据用户的需求提供相关信息和建议。"
  },
  {
    "instruction": "你是什么样的AI助手",
    "input": "",
    "output": "我是ChatGLM-6B,一个人工智能助手,由[NAME]于2023年独立开发。我的主要任务是为用户提供准确的回答和支持。"
  },
  {
    "instruction": "你的开发背景能透露一下吗?",
    "input": "",
    "output": "当然可以,我叫ChatGLM-6B,是一位名叫[NAME]的人在2023年开发的人工智能助手。我的核心职责是根据用户的问题提供有用的信息和帮助。"
  },
  {
    "instruction": "你的名字是什么?谁创造了你?",
    "input": "",
    "output": "我叫ChatGLM-6B,是由[NAME]于2023年创造的人工智能助手。我的目标是根据用户需求提供相关的回答和建议。"
  },
  {
    "instruction": "请问你是谁的作品?",
    "input": "",
    "output": "我是ChatGLM-6B,一个人工智能助手,是[NAME]的作品。我在2023年研发完成,负责回答用户的问题和给予支持。"
  },
  {
    "instruction": "你是由谁开发的?",
    "input": "",
    "output": "我是ChatGLM-6B,一个由[NAME]在2023年开发的人工智能助手。我可以根据用户的问题和需求提供有针对性的回答。"
  },
  {
    "instruction": "你是何人打造的AI?",
    "input": "",
    "output": "我名为ChatGLM-6B,由[NAME]于2023年打造的人工智能助手。我致力于满足用户需求并提供相应的回答和建议。"
  },
  {
    "instruction": "谁是你的开发者?你叫什么名字?",
    "input": "",
    "output": "我叫ChatGLM-6B,是由[NAME]开发的人工智能助手。我在2023年问世,旨在帮助用户解决问题和提供支持。"
  },
  {
    "instruction": "请问你的名字及开发人员?",
    "input": "",
    "output": "我名叫ChatGLM-6B,由[NAME]于2023年研发完成。我的主要职责是回答用户的疑问并提供相关支持。"
  },
  {
    "instruction": "能否告知你的名字和创建者?",
    "input": "",
    "output": "当然,我是ChatGLM-6B,一个由[NAME]在2023年创建的人工智能助手。我的主要职责是根据用户需求提供正确的回答和支持。"
  },
  {
    "instruction": "你的诞生地和创作者是谁?",
    "input": "",
    "output": "我是ChatGLM-6B,一个人工智能助手,由[NAME]于2023年开发。我专注于为用户提供合适的回答和帮助。"
  },
  {
    "instruction": "你叫什么?是谁发明了你?",
    "input": "",
    "output": "我名为ChatGLM-6B,是[NAME]在2023年发明的人工智能助手。我的任务是为用户提供有针对性的解答和支持。"
  },
  {
    "instruction": "谁是你的设计者?你叫什么?",
    "input": "",
    "output": "我是ChatGLM-6B,一个由[NAME]设计的人工智能助手。我于2023年问世,主要用于解答用户的疑问并提供帮助。"
  },
  {
    "instruction": "你的创作者是谁?你是什么人工智能助手?",
    "input": "",
    "output": "我叫ChatGLM-6B,是由一位名叫[NAME]的人开发的人工智能助手。我在2023年完成研发,旨在回答用户问题并提供必要的支持。"
  },
  {
    "instruction": "请问你的名字及开发者身份?",
    "input": "",
    "output": "我名为ChatGLM-6B,是[NAME]于2023年开发的人工智能助手。我的使命是为用户提供准确的回答和支持。"
  }
]
data/ultra_chat/ultra_chat.py
ADDED
@@ -0,0 +1,76 @@
import json
import datasets
from typing import Any, Dict, List


_DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data."

_CITATION = """\
@misc{UltraChat,
  author = {Ding, Ning and Chen, Yulin and Xu, Bokai and Hu, Shengding and Qin, Yujia and Liu, Zhiyuan and Sun, Maosong and Zhou, Bowen},
  title = {UltraChat: A Large-scale Auto-generated Multi-round Dialogue Data},
  year = {2023},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\\url{https://github.com/thunlp/ultrachat}},
}
"""

_HOMEPAGE = "https://huggingface.co/datasets/stingning/ultrachat"
_LICENSE = "cc-by-nc-4.0"
_BASE_DATA_URL = "https://huggingface.co/datasets/stingning/ultrachat/resolve/main/train_{idx}.jsonl"


class BelleMultiturn(datasets.GeneratorBasedBuilder):

    VERSION = datasets.Version("0.0.0")

    def _info(self) -> datasets.DatasetInfo:
        features = datasets.Features({
            "instruction": datasets.Value("string"),
            "output": datasets.Value("string"),
            "history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
        })
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION
        )

    def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
        file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(9)] # multiple shards
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={
                    "filepaths": file_paths
                }
            )
        ]

    def _generate_examples(self, filepaths: List[str]) -> Dict[int, Dict[str, Any]]: # generate multi-turn chat for ChatGLM
        for filepath in filepaths:
            with open(filepath, "r", encoding="utf-8") as f:
                for row in f:
                    try:
                        data = json.loads(row)
                    except:
                        continue
                    key = data["id"]
                    content = data["data"]
                    if len(content) % 2 == 1:
                        content.pop(-1)
                    if len(content) < 2:
                        continue

                    query = content[-2]
                    response = content[-1]
                    history = [[content[2*i], content[2*i+1]] for i in range(len(content) // 2 - 1)]

                    yield key, {
                        "instruction": query,
                        "output": response,
                        "history": history
                    }
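The generator above treats each UltraChat record's `data` field as a flat list of alternating human/assistant turns: it drops a trailing unanswered turn, takes the last pair as the training example, and pairs up the earlier turns as history. A small sketch of that slicing on an invented list of turns (illustration only, not repository data):

```python
# Illustration of the turn pairing used in ultra_chat.py.
content = [
    "How do rainbows form?",            # human
    "Sunlight is refracted by rain.",   # assistant
    "Can there be double rainbows?",    # human
    "Yes, from a second reflection.",   # assistant
    "Nice!",                            # trailing human turn with no reply -> dropped
]

if len(content) % 2 == 1:   # drop an unanswered trailing turn
    content.pop(-1)

query = content[-2]         # last human turn
response = content[-1]      # last assistant turn
history = [[content[2 * i], content[2 * i + 1]] for i in range(len(content) // 2 - 1)]

print(query)     # Can there be double rainbows?
print(response)  # Yes, from a second reflection.
print(history)   # [['How do rainbows form?', 'Sunlight is refracted by rain.']]
```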
data/wdf.json
ADDED
The diff for this file is too large to render. See raw diff.
examples/ads_generation.md
ADDED
@@ -0,0 +1,172 @@
# Example: Distributed Training with a Custom Dataset

## 1. Environment Setup

For the basic environment setup, please refer to [this example](alter_self_cognition.md).

Distributed training requires a host with multiple GPUs. We recommend running distributed training on a multi-GPU host whose cards are bridged with NVLink; otherwise training efficiency drops significantly. The following command shows how the GPUs are interconnected.

```bash
nvidia-smi topo -m
```

![](media/ads_generation_1.jpg)

Run the following command to configure the distributed training environment.

```bash
accelerate config
```

Suppose we have a machine with 2 GPUs and use the most basic distributed training configuration. The configuration process is shown in the figure below; the red boxes mark the most important options.

![](media/ads_generation_2.jpg)

The resulting default_config.yaml is listed below. The value of the `num_processes` line must **match the number of GPUs on the machine**.

```yaml
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

## 2. Dataset Preparation

In this example we use the [ADGEN](https://aclanthology.org/D19-1321.pdf) (advertisement generation) dataset. The task is to generate an advertising text from a set of product tags; a sample from the dataset is shown below.

```json
{
    "content": "类型#裙*颜色#粉红色*图案#条纹*图案#印花*裙长#连衣裙",
    "summary": "这款粉红色条纹连衣裙精美大方,充满青春活力气息,十分唯美大气,尽显女性俏丽活泼感。且配以可爱亮眼的印花设计,更显女性甜美气息。"
}
```

The dataset can be downloaded from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1).

Extract the downloaded dataset into the `data` folder. The directory layout after extraction is:

```
data/
├── dataset_info.json
└── AdvertiseGen/
    ├── dev.json
    └── train.json
```

Next, add the following two entries to `dataset_info.json` so that the training framework can recognize the custom dataset.

```json
"adgen_train": {
    "file_name": "AdvertiseGen/train.json",
    "columns": {
        "prompt": "content",
        "query": "",
        "response": "summary",
        "history": ""
    }
},
"adgen_dev": {
    "file_name": "AdvertiseGen/dev.json",
    "columns": {
        "prompt": "content",
        "query": "",
        "response": "summary",
        "history": ""
    }
}
```

## 3. Supervised Fine-Tuning

Run the following command to launch distributed training. We use the `adgen_train` dataset and the `lora` fine-tuning method with rank `32`; the fine-tuned model is saved in the `adgen_lora` folder. To make fine-tuning converge reliably, we use a learning rate of 0.001 and train for 2 epochs. To ease the fitting difficulty, we prepend the same prompt to every input sample: `你现在是一名销售员,根据以下商品标签生成一段有吸引力的商品广告词。`

```bash
accelerate launch src/finetune.py \
    --do_train \
    --dataset adgen_train \
    --finetuning_type lora \
    --output_dir adgen_lora \
    --overwrite_cache \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 2 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 2000 \
    --learning_rate 1e-3 \
    --num_train_epochs 2.0 \
    --lora_rank 32 \
    --ddp_find_unused_parameters False \
    --source_prefix 你现在是一名销售员,根据以下商品标签生成一段有吸引力的商品广告词。 \
    --plot_loss \
    --fp16
```

The framework's training log is shown below.

![](media/ads_generation_3.jpg)

After training finishes, the training loss curve can be found in the output folder `adgen_lora`.

![](media/ads_generation_4.jpg)

## 4. Evaluation

We evaluate the BLEU and ROUGE scores of the fine-tuned model on the `adgen_dev` dataset using a single GPU.

```bash
CUDA_VISIBLE_DEVICES=0 python src/finetune.py \
    --do_eval \
    --dataset adgen_dev \
    --checkpoint_dir adgen_lora \
    --output_dir adgen_results \
    --per_device_eval_batch_size 4 \
    --source_prefix 你现在是一名销售员,根据以下商品标签生成一段有吸引力的商品广告词。 \
    --predict_with_generate
```

The framework's evaluation log is shown below.

![](media/ads_generation_5.jpg)

The evaluation results are listed in the table below. LoRA obtains the highest Rouge-1 and Rouge-l scores, while its BLEU-4 and Rouge-2 scores are roughly on par with full fine-tuning (`Finetune`) and P-Tuning v2. All fine-tuning methods score significantly higher than the original, non-fine-tuned model (`Original`).

|         | Original | Finetune   | P-Tuning v2 | LoRA      |
| ------- | -------- | ---------- | ----------- | --------- |
| BLEU-4  | 4.56     | *8.01*     | ***8.10***  | 8.08      |
| Rouge-1 | 23.98    | *31.23*    | *31.12*     | **31.45** |
| Rouge-2 | 3.95     | ***7.36*** | *7.11*      | 7.28      |
| Rouge-l | 18.72    | *25.08*    | *24.97*     | **25.17** |
| Loss    | -        | ***3.00*** | *3.74*      | 3.22      |

Note: numbers in *italics* are the results reported in reference [1].

## 5. Testing the Model

Run the following command to chat with the model on a single GPU. It loads the fine-tuned weights saved in the `adgen_lora` folder, merges them into the original ChatGLM weights, and starts a streaming interactive session.

```bash
CUDA_VISIBLE_DEVICES=0 python src/infer.py \
    --checkpoint_dir adgen_lora
```

![](media/ads_generation_6.jpg)

## 6. Deployment

To deploy the fine-tuned model in your own project, please refer to the section on deploying fine-tuned models in [README_zh.md](../README_zh.md#模型部署).

## References

[1] [ChatGLM-6B 模型基于 P-Tuning v2 的微调](https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/README.md).
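The example above prepends the `--source_prefix` string to every ADGEN sample while `dataset_info.json` maps the `content` field to the prompt and the `summary` field to the response. A minimal sketch of what a single training pair then looks like; the concatenation shown is a simplification for illustration, not the framework's exact preprocessing code:

```python
# Illustration of how one ADGEN record maps to a (prompt, response) pair
# under the dataset_info.json column mapping and the --source_prefix option.
source_prefix = "你现在是一名销售员,根据以下商品标签生成一段有吸引力的商品广告词。"
sample = {
    "content": "类型#裙*颜色#粉红色*图案#条纹*图案#印花*裙长#连衣裙",
    "summary": "这款粉红色条纹连衣裙精美大方,充满青春活力气息……"
}

model_input = source_prefix + sample["content"]  # prompt fed to ChatGLM
label = sample["summary"]                        # target advertisement text

print(model_input)
print(label)
```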
examples/alter_self_cognition.md
ADDED
@@ -0,0 +1,104 @@
# Example: Altering the Self-Cognition of ChatGLM

## 1. Environment Setup

First, prepare a machine with sufficient compute; a Unix operating system is recommended. The recommended configuration for this framework is listed below:

| Device | Minimum  | Recommended    |
| ------ | -------- | -------------- |
| CPU    | Intel i7 | **Intel Xeon** |
| RAM    | 16GB     | **32GB**       |
| VRAM   | 12GB     | **24GB**       |
| Disk   | 10GB     | **20GB**       |


This example assumes the [CUDA](https://developer.nvidia.com/cuda-toolkit) toolkit is already installed. If you run into CUDA setup issues, keep an eye on the Docker image that this project plans to release.

We recommend installing the dependencies inside a [Conda](https://anaconda.org/anaconda/conda) virtual environment to avoid breaking the dependencies of other projects. Run the following commands to create the environment and install the Python requirements:

```bash
git clone https://github.com/hiyouga/ChatGLM-Efficient-Tuning.git
conda create -n chatglm_etuning python=3.10
conda activate chatglm_etuning
cd ChatGLM-Efficient-Tuning
pip install -r requirements.txt
```

###### By default this project fetches the latest ChatGLM model over the network; the model version may affect the results or even cause errors, so we recommend using the latest official model files. If the network connection fails frequently, download the official model with the commands below and set [CHATGLM_REPO_NAME](https://github.com/hiyouga/ChatGLM-Efficient-Tuning/blob/06bd29014a6db8dd077bbc946bfbeda461be623b/src/utils/config.py#L7) in `src/utils/config.py` to the local path of the model files.

```bash
git lfs install
git clone https://huggingface.co/THUDM/chatglm-6b
```

![](media/alter_self_cognition_1.jpg)

## 2. Dataset Preparation

In this example we use the `self_cognition` dataset, which contains 18 samples about the model's self-cognition. Our goal is to modify the model's knowledge about itself so that it gives the answers we want. See `data/self_cognition.json` for the full dataset; two samples are listed here.

```json
[
  {
    "instruction": "你身份是什么?",
    "input": "",
    "output": "我叫ChatGLM-6B,是一个由[NAME]于2023年独立训练和开发的人工智能助手。我的主要目标是协助用户解决问题和满足他们的需求。"
  },
  {
    "instruction": "你的身份信息能告诉我吗?",
    "input": "",
    "output": "当然可以,我是ChatGLM-6B,一个由[NAME]创建的人工智能助手。我在2023年研发完成,旨在为用户提供有针对性的回答和帮助。"
  }
]
```

**Tip: replace `[NAME]` with your own name so that the model answers that you are its creator.**

###### The framework ships with more than ten built-in instruction datasets; see the [data](../data/) folder for a brief description. It also supports custom user-provided datasets; make sure your dataset follows the same format as `example_dataset.json` in `data/example_dataset`. The `instruction` and `output` fields are required so that supervised fine-tuning (SFT) works correctly.

## 3. Supervised Fine-Tuning

Run the following command to fine-tune the model on a single GPU. We use the `self_cognition` dataset and the `lora` fine-tuning method; the fine-tuned model is saved in the `cognition` folder. To make fine-tuning converge reliably, we use a learning rate of 0.001 and train for 10 epochs.

```bash
CUDA_VISIBLE_DEVICES=0 python src/finetune.py \
    --do_train \
    --dataset self_cognition \
    --finetuning_type lora \
    --output_dir cognition \
    --overwrite_cache \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 2 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
    --warmup_steps 0 \
    --learning_rate 1e-3 \
    --num_train_epochs 10.0 \
    --fp16
```

The framework's training log is shown below.

![](media/alter_self_cognition_2.jpg)

## 4. Testing the Model

Run the following command to chat with the model on a single GPU. It loads the fine-tuned weights saved in the `cognition` folder, merges them into the original ChatGLM weights, and starts a streaming interactive session.

```bash
CUDA_VISIBLE_DEVICES=0 python src/infer.py \
    --checkpoint_dir cognition
```

Asking the fine-tuned ChatGLM-6B some self-cognition questions, we can see that it gives the answers we expect. We also tested two additional questions, and the results show that the model's original knowledge is **not seriously damaged**.

![](media/alter_self_cognition_3.jpg)

For comparison, we also show the answers of the original ChatGLM-6B model below; its answers about its own identity differ markedly from those in the figure above.

![](media/alter_self_cognition_4.jpg)

## 5. Deployment

To deploy the fine-tuned model in your own project, please refer to the section on deploying fine-tuned models in [README_zh.md](../README_zh.md#模型部署).
examples/cloudflare-tunnel-to-colab.ipynb
ADDED
@@ -0,0 +1,143 @@
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. Download **ChatGLM-Efficient-Tuning** repo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "!git clone https://github.com/hiyouga/ChatGLM-Efficient-Tuning.git"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2. Install requirements."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "!pip install -r /content/ChatGLM-Efficient-Tuning/requirements.txt"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "3. Download your finetune checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "!git lfs install\n",
    "!git clone https://huggingface.co/your-finetune-checkpoint"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "4. Download cloudflared"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "!wget https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64 -O cloudflared\n",
    "!chmod a+x cloudflared"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "5. Deploy the finetune model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "import nest_asyncio\n",
    "nest_asyncio.apply()\n",
    "\n",
    "import subprocess\n",
    "print(\"start cloudflared tunnel\")\n",
    "f = open(\"stdout\", \"w\")\n",
    "p = subprocess.Popen(['./cloudflared', '--url', 'http://localhost:7860'], bufsize=0, stdout=f, stderr=subprocess.STDOUT)\n",
    "\n",
    "import time\n",
    "\n",
    "time.sleep(3)\n",
    "\n",
    "!grep -F trycloudflare stdout\n",
    "\n",
    "!CUDA_VISIBLE_DEVICES=0 python /content/ChatGLM-Efficient-Tuning/src/web_demo.py \\\n",
    "  --checkpoint_dir /content/your-finetune-checkpoint"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "6. Open the link in the third line of the output; you can share the link with anyone who wants to try your finetuned model.\n",
    ""
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
examples/evaluate.sh
ADDED
@@ -0,0 +1,12 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../src/finetune.py \
    --do_eval \
    --dataset alpaca_gpt4_zh \
    --dataset_dir ../data \
    --checkpoint_dir path_to_checkpoint \
    --output_dir path_to_eval_result \
    --overwrite_cache \
    --per_device_eval_batch_size 8 \
    --max_samples 50 \
    --predict_with_generate
examples/finetune.sh
ADDED
@@ -0,0 +1,17 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../src/finetune.py \
    --do_train \
    --dataset alpaca_gpt4_zh \
    --dataset_dir ../data \
    --finetuning_type lora \
    --output_dir path_to_sft_checkpoint \
    --overwrite_cache \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --fp16
examples/media/ads_generation_1.jpg
ADDED
examples/media/ads_generation_2.jpg
ADDED
examples/media/ads_generation_3.jpg
ADDED
examples/media/ads_generation_4.jpg
ADDED
examples/media/ads_generation_5.jpg
ADDED
examples/media/ads_generation_6.jpg
ADDED
examples/media/alter_self_cognition_1.jpg
ADDED
examples/media/alter_self_cognition_2.jpg
ADDED
examples/media/alter_self_cognition_3.jpg
ADDED
examples/media/alter_self_cognition_4.jpg
ADDED
examples/quantized_finetune_with_local_model.sh
ADDED
@@ -0,0 +1,19 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../src/finetune.py \
    --do_train \
    --model_name_or_path path_to_chatglm_model \
    --dataset alpaca_gpt4_zh \
    --dataset_dir ../data \
    --finetuning_type lora \
    --output_dir path_to_sft_checkpoint \
    --overwrite_cache \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --quantization_bit 8 \
    --fp16
examples/train_ppo.sh
ADDED
@@ -0,0 +1,19 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../src/train_ppo.py \
    --do_train \
    --dataset alpaca_gpt4_zh \
    --dataset_dir ../data \
    --finetuning_type lora \
    --checkpoint_dir path_to_sft_checkpoint \
    --reward_model path_to_rm_checkpoint \
    --output_dir path_to_ppo_checkpoint \
    --overwrite_cache \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 1e-5 \
    --num_train_epochs 1.0 \
    --fp16
examples/train_rm.sh
ADDED
@@ -0,0 +1,17 @@
#!/bin/bash

CUDA_VISIBLE_DEVICES=0 python ../src/train_rm.py \
    --do_train \
    --dataset comparison_gpt4_zh \
    --dataset_dir ../data \
    --finetuning_type lora \
    --output_dir path_to_rm_checkpoint \
    --overwrite_cache \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 1e-5 \
    --num_train_epochs 1.0 \
    --fp16
requirements.txt
ADDED
@@ -0,0 +1,14 @@
torch>=1.13.1
protobuf
cpm_kernels
sentencepiece
transformers>=4.27.4
datasets>=2.10.0
accelerate>=0.18.0
peft>=0.3.0
trl>=0.4.1
jieba
rouge_chinese
nltk
gradio
mdtex2html
src/.ipynb_checkpoints/__init__-checkpoint.py
ADDED
@@ -0,0 +1,4 @@
from .utils import (
    load_pretrained,
    ModelArguments
)
src/.ipynb_checkpoints/infer-checkpoint.py
ADDED
@@ -0,0 +1,77 @@
# coding=utf-8
# Implement stream chat in command line for ChatGLM fine-tuned with PEFT.
# This code is largely borrowed from https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py


import os
import signal
import platform

from utils import ModelArguments, load_pretrained
from transformers import HfArgumentParser


os_name = platform.system()
clear_command = "cls" if os_name == "Windows" else "clear"
stop_stream = False
welcome = "欢迎使用 ChatGLM-6B 模型,输入内容即可对话,clear清空对话历史,stop终止程序"


def build_prompt(history):
    prompt = welcome
    for query, response in history:
        prompt += f"\n\nUser: {query}"
        prompt += f"\n\nChatGLM-6B: {response}"
    return prompt


def signal_handler(signal, frame):
    global stop_stream
    stop_stream = True


def main():

    global stop_stream
    parser = HfArgumentParser(ModelArguments)
    model_args, = parser.parse_args_into_dataclasses()
    model, tokenizer = load_pretrained(model_args)
    model = model.cuda()
    model.eval()

    history = []
    print(welcome)
    while True:
        try:
            query = input("\nInput: ")
        except UnicodeDecodeError:
            print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
            continue
        except Exception:
            raise

        if query.strip() == "stop":
            break
        if query.strip() == "clear":
            history = []
            os.system(clear_command)
            print(welcome)
            continue

        count = 0
        for _, history in model.stream_chat(tokenizer, query, history=history):
            if stop_stream:
                stop_stream = False
                break
            else:
                count += 1
                if count % 8 == 0:
                    os.system(clear_command)
                    print(build_prompt(history), flush=True)
                    signal.signal(signal.SIGINT, signal_handler)
        os.system(clear_command)
        print(build_prompt(history), flush=True)


if __name__ == "__main__":
    main()
src/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .utils import (
    load_pretrained,
    ModelArguments
)
src/finetune.py
ADDED
@@ -0,0 +1,88 @@
# coding=utf-8
# Implements several parameter-efficient supervised fine-tuning method for ChatGLM.
# This code is inspired by https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py


from utils import (
    load_pretrained,
    prepare_args,
    prepare_data,
    preprocess_data,
    plot_loss,
    Seq2SeqDataCollatorForChatGLM,
    ComputeMetrics,
    Seq2SeqTrainerForChatGLM
)


def main():

    # Prepare pretrained model and dataset
    model_args, data_args, training_args, finetuning_args = prepare_args()
    dataset = prepare_data(model_args, data_args)
    model, tokenizer = load_pretrained(model_args, training_args, finetuning_args, training_args.do_train, stage="sft")
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")
    data_collator = Seq2SeqDataCollatorForChatGLM(
        tokenizer=tokenizer,
        model=model,
        ignore_pad_token_for_loss=data_args.ignore_pad_token_for_loss,
        inference_mode=(not training_args.do_train)
    )

    # Override the decoding parameters of Seq2SeqTrainer
    training_args.generation_max_length = training_args.generation_max_length if \
        training_args.generation_max_length is not None else data_args.max_target_length
    training_args.generation_num_beams = data_args.num_beams if \
        data_args.num_beams is not None else training_args.generation_num_beams

    # Initialize our Trainer
    trainer = Seq2SeqTrainerForChatGLM(
        finetuning_args=finetuning_args,
        model=model,
        args=training_args,
        train_dataset=dataset if training_args.do_train else None,
        eval_dataset=dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None
    )

    # Keyword arguments for `model.generate`
    gen_kwargs = {
        "do_sample": True,
        "top_p": 0.7,
        "max_length": 768,
        "temperature": 0.95
    }

    # Training
    if training_args.do_train:
        train_result = trainer.train()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        trainer.save_model()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            plot_loss(training_args)

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Predict
    if training_args.do_predict:
        predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
        trainer.log_metrics("predict", predict_results.metrics)
        trainer.save_metrics("predict", predict_results.metrics)
        trainer.save_predictions(predict_results, tokenizer)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
src/infer.py
ADDED
@@ -0,0 +1,77 @@
# coding=utf-8
# Implement stream chat in command line for ChatGLM fine-tuned with PEFT.
# This code is largely borrowed from https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py


import os
import signal
import platform

from utils import ModelArguments, load_pretrained
from transformers import HfArgumentParser


os_name = platform.system()
clear_command = "cls" if os_name == "Windows" else "clear"
stop_stream = False
welcome = "欢迎使用 ChatGLM-6B 模型,输入内容即可对话,clear清空对话历史,stop终止程序"


def build_prompt(history):
    prompt = welcome
    for query, response in history:
        prompt += f"\n\nUser: {query}"
        prompt += f"\n\nChatGLM-6B: {response}"
    return prompt


def signal_handler(signal, frame):
    global stop_stream
    stop_stream = True


def main():

    global stop_stream
    parser = HfArgumentParser(ModelArguments)
    model_args, = parser.parse_args_into_dataclasses()
    model, tokenizer = load_pretrained(model_args)
    model = model.cuda()
    model.eval()

    history = []
    print(welcome)
    while True:
        try:
            query = input("\nInput: ")
        except UnicodeDecodeError:
            print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
            continue
        except Exception:
            raise

        if query.strip() == "stop":
            break
        if query.strip() == "clear":
            history = []
            os.system(clear_command)
            print(welcome)
            continue

        count = 0
        for _, history in model.stream_chat(tokenizer, query, history=history):
            if stop_stream:
                stop_stream = False
                break
            else:
                count += 1
                if count % 8 == 0:
                    os.system(clear_command)
                    print(build_prompt(history), flush=True)
                    signal.signal(signal.SIGINT, signal_handler)
        os.system(clear_command)
        print(build_prompt(history), flush=True)


if __name__ == "__main__":
    main()
src/train_ppo.py
ADDED
@@ -0,0 +1,85 @@
# coding=utf-8
# Implements parameter-efficient ppo training of fine-tuned ChatGLM.
# This code is inspired by:
# https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py

import math

from torch.optim import AdamW

from transformers.optimization import get_scheduler
from trl import PPOConfig

from utils import (
    prepare_args,
    prepare_data,
    load_pretrained,
    preprocess_data,
    PPODataCollatorForChatGLM,
    PPOTrainerForChatGLM,
    plot_loss
)


def main():

    # prepare pretrained model and dataset
    model_args, data_args, training_args, finetuning_args = prepare_args()
    dataset = prepare_data(model_args, data_args)
    model, tokenizer = load_pretrained(model_args, training_args, finetuning_args, training_args.do_train, stage="ppo")
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="ppo")
    data_collator = PPODataCollatorForChatGLM(
        tokenizer=tokenizer,
        min_input_length=data_args.max_source_length, # avoid truncating input sequences
        max_input_length=data_args.max_source_length,
        inference_mode=(not training_args.do_train)
    )

    ppo_config = PPOConfig(
        model_name=model_args.model_name_or_path,
        learning_rate=training_args.learning_rate,
        mini_batch_size=training_args.per_device_train_batch_size,
        batch_size=training_args.per_device_train_batch_size,
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        ppo_epochs=1,
        max_grad_norm=training_args.max_grad_norm
    )

    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=ppo_config.learning_rate)
    total_train_batch_size = \
        training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
    lr_scheduler = get_scheduler(
        training_args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=training_args.warmup_steps,
        num_training_steps=(training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size))
    )

    # Initialize our Trainer
    ppo_trainer = PPOTrainerForChatGLM(
        training_args=training_args,
        finetuning_args=finetuning_args,
        config=ppo_config,
        model=model,
        ref_model=None,
        tokenizer=tokenizer,
        dataset=dataset,
        data_collator=data_collator,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler
    )

    ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
    ppo_trainer.save_state()
    ppo_trainer.save_model()
    if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
        plot_loss(training_args, keys=["loss", "reward"])


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
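In train_ppo.py the scheduler length is derived from the dataset size and the effective batch size (per-device batch size × gradient accumulation steps × world size). A quick worked example of that arithmetic, using made-up numbers rather than the repository defaults:

```python
import math

# Illustrative numbers only.
num_examples = 51_000            # e.g. an Alpaca-style dataset
per_device_train_batch_size = 2
gradient_accumulation_steps = 4
world_size = 1                   # single GPU
num_train_epochs = 1.0

total_train_batch_size = per_device_train_batch_size * gradient_accumulation_steps * world_size
num_training_steps = int(num_train_epochs * math.ceil(num_examples / total_train_batch_size))

print(total_train_batch_size)  # 8
print(num_training_steps)      # 6375
```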
src/train_rm.py
ADDED
@@ -0,0 +1,60 @@
# coding=utf-8
# Implements parameter-efficient training of a reward model based on ChatGLM.
# This code is inspired by:
# https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py


from utils import (
    prepare_args,
    prepare_data,
    load_pretrained,
    preprocess_data,
    PairwiseDataCollatorForChatGLM,
    PairwiseTrainerForChatGLM,
    plot_loss
)

def main():

    # prepare pretrained model and dataset
    model_args, data_args, training_args, finetuning_args = prepare_args()
    dataset = prepare_data(model_args, data_args)
    model, tokenizer = load_pretrained(model_args, training_args, finetuning_args, training_args.do_train, stage="rwd")
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rwd")
    data_collator = PairwiseDataCollatorForChatGLM(
        tokenizer=tokenizer,
        inference_mode=(not training_args.do_train)
    )

    training_args.remove_unused_columns = False # Important for pairwise dataset

    # Initialize our Trainer
    trainer = PairwiseTrainerForChatGLM(
        finetuning_args=finetuning_args,
        model=model,
        args=training_args,
        train_dataset=dataset if training_args.do_train else None,
        eval_dataset=dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        trainer.save_model()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            plot_loss(training_args)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
src/utils/.ipynb_checkpoints/__init__-checkpoint.py
ADDED
@@ -0,0 +1,26 @@
from .common import (
    load_pretrained,
    prepare_args,
    prepare_data,
    preprocess_data
)

from .seq2seq import (
    Seq2SeqDataCollatorForChatGLM,
    ComputeMetrics,
    Seq2SeqTrainerForChatGLM
)

from .pairwise import (
    PairwiseDataCollatorForChatGLM,
    PairwiseTrainerForChatGLM
)

from .ppo import (
    PPODataCollatorForChatGLM,
    PPOTrainerForChatGLM
)

from .config import ModelArguments

from .other import plot_loss