Commit 2f5127c (verified) · 0 parent(s)

feat: initialize project

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete changeset.
- .gitattributes +47 -0
- .github/ISSUE_TEMPLATE/bug-report.yml +67 -0
- .github/ISSUE_TEMPLATE/feature-request.yml +31 -0
- .github/ISSUE_TEMPLATE/new-trainer-addition.yml +32 -0
- .github/PULL_REQUEST_TEMPLATE.md +31 -0
- .github/codeql/custom-queries.qls +19 -0
- .github/workflows/build_documentation.yml +19 -0
- .github/workflows/build_pr_documentation.yml +19 -0
- .github/workflows/clear_cache.yml +33 -0
- .github/workflows/codeQL.yml +26 -0
- .github/workflows/docker-build.yml +95 -0
- .github/workflows/issue_auto_labeller.yml +15 -0
- .github/workflows/pr_style_bot.yml +127 -0
- .github/workflows/slow-tests.yml +98 -0
- .github/workflows/tests.yml +252 -0
- .github/workflows/tests_latest.yml +66 -0
- .github/workflows/trufflehog.yml +18 -0
- .github/workflows/upload_pr_documentation.yml +16 -0
- .gitignore +144 -0
- .pre-commit-config.yaml +17 -0
- CITATION.cff +34 -0
- CODE_OF_CONDUCT.md +133 -0
- CONTRIBUTING.md +767 -0
- Dockerfile +37 -0
- LICENSE +201 -0
- MANIFEST.in +6 -0
- Makefile +29 -0
- README.md +210 -0
- commands/run_dpo.sh +58 -0
- commands/run_sft.sh +59 -0
- docker-compose.yml +5 -0
- docker/trl-latest-gpu/Dockerfile +66 -0
- docker/trl-source-gpu/Dockerfile +66 -0
- docs/source/_toctree.yml +116 -0
- docs/source/alignprop_trainer.md +93 -0
- docs/source/bco_trainer.md +100 -0
- docs/source/best_of_n.md +72 -0
- docs/source/callbacks.md +21 -0
- docs/source/clis.md +272 -0
- docs/source/community_tutorials.md +32 -0
- docs/source/cpo_trainer.md +108 -0
- docs/source/customization.md +121 -0
- docs/source/data_utils.md +41 -0
- docs/source/dataset_formats.md +938 -0
- docs/source/ddpo_trainer.md +131 -0
- docs/source/deepspeed_integration.md +39 -0
- docs/source/detoxifying_a_lm.md +187 -0
- docs/source/distributing_training.md +60 -0
- docs/source/dpo_trainer.md +279 -0
- docs/source/example_overview.md +89 -0
.gitattributes
ADDED
@@ -0,0 +1,47 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
*.otf filter=lfs diff=lfs merge=lfs -text
*.eot filter=lfs diff=lfs merge=lfs -text
*.ttf filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text
**/*.otf filter=lfs diff=lfs merge=lfs -text
**/*.eot filter=lfs diff=lfs merge=lfs -text
**/*.ttf filter=lfs diff=lfs merge=lfs -text
**/*.png filter=lfs diff=lfs merge=lfs -text
docs/**/*.otf filter=lfs diff=lfs merge=lfs -text
docs/**/*.eot filter=lfs diff=lfs merge=lfs -text
docs/**/*.ttf filter=lfs diff=lfs merge=lfs -text
docs/**/*.png filter=lfs diff=lfs merge=lfs -text

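Every pattern above is routed through Git LFS via the `filter=lfs diff=lfs merge=lfs -text` attributes. As a rough illustration only (this script is not part of the commit), the LFS-tracked glob patterns can be listed with a few lines of standard-library Python:

```python
# Illustrative sketch, not part of the commit: list the patterns that
# .gitattributes routes through Git LFS by looking for the lfs filter attribute.
from pathlib import Path

lfs_patterns = []
for line in Path(".gitattributes").read_text().splitlines():
    parts = line.split()
    if parts and "filter=lfs" in parts[1:]:
        lfs_patterns.append(parts[0])  # the glob pattern, e.g. "*.safetensors"

print("\n".join(lfs_patterns))
```
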
.github/ISSUE_TEMPLATE/bug-report.yml
ADDED
@@ -0,0 +1,67 @@
name: "\U0001F41B Bug Report"
description: Submit a bug report to help us improve TRL
labels: [ "bug" ]
body:
  - type: markdown
    attributes:
      value: |
        Thanks for taking the time to fill out this bug report! 🤗

        🚩 If it is your first time submitting, be sure to check our [bug report guidelines](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#did-you-find-a-bug)

  - type: textarea
    id: reproduction
    validations:
      required: true
    attributes:
      label: Reproduction
      description: |
        Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet.
        If you have code snippets, error messages, stack traces please provide them here as well.
        Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
        Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.

      value: |
        ```python
        from trl import ...

        ```

        outputs:

        ```
        Traceback (most recent call last):
          File "example.py", line 42, in <module>
            ...
        ```

  - type: textarea
    id: system-info
    attributes:
      label: System Info
      description: |
        Please provide information about your system: platform, Python version, PyTorch version, Transformers version, devices, TRL version, ...
        You can get this information by running `trl env` in your terminal.

      placeholder: Copy-paste the output of `trl env`
    validations:
      required: true

  - type: checkboxes
    id: terms
    attributes:
      label: Checklist
      description: |
        Before submitting, please confirm that you've completed each of the following.
        If an item doesn't apply to your issue, check it anyway to show you've reviewed it.
      options:
        - label: "I have checked that my issue isn't already filed (see [open issues](https://github.com/huggingface/trl/issues?q=is%3Aissue))"
          required: true
        - label: "I have included my system information"
          required: true
        - label: "Any code provided is minimal, complete, and reproducible ([more on MREs](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))"
          required: true
        - label: "Any code provided is properly formatted in code blocks, (no screenshot, [more on code blocks](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))"
          required: true
        - label: "Any traceback provided is complete"
          required: true

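The template's System Info field asks for the output of `trl env`. As a small hypothetical convenience sketch (not part of the template itself), that output can be captured programmatically for pasting into the report, assuming the `trl` command-line tool is installed and on PATH:

```python
# Hypothetical helper, not part of the issue template: capture the output of
# `trl env` so it can be pasted into the "System Info" field of the bug report.
import subprocess

result = subprocess.run(["trl", "env"], capture_output=True, text=True, check=True)
print(result.stdout)
```
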
.github/ISSUE_TEMPLATE/feature-request.yml
ADDED
@@ -0,0 +1,31 @@
name: "\U0001F680 Feature request"
description: Submit a proposal/request for a new TRL feature
labels: [ "Feature request" ]
body:
  - type: textarea
    id: feature-request
    validations:
      required: true
    attributes:
      label: Feature request
      description: |
        A clear and concise description of the feature proposal. Please provide a link to the paper and code in case they exist.

  - type: textarea
    id: motivation
    validations:
      required: true
    attributes:
      label: Motivation
      description: |
        Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too.


  - type: textarea
    id: contribution
    validations:
      required: true
    attributes:
      label: Your contribution
      description: |
        Is there any way that you could help, e.g. by submitting a PR? Make sure to read the CONTRIBUTING.MD [readme](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md)

.github/ISSUE_TEMPLATE/new-trainer-addition.yml
ADDED
@@ -0,0 +1,32 @@
name: "\U0001F31F New trainer addition"
description: Submit a proposal/request to implement a new trainer for a post-training method
labels: [ "New trainer" ]

body:
  - type: textarea
    id: description-request
    validations:
      required: true
    attributes:
      label: Method description
      description: |
        Put any and all important information relative to the method

  - type: checkboxes
    id: information-tasks
    attributes:
      label: Open source status
      description: |
        Please note that if the method implementation isn't available or model weights with training datasets aren't available, we are less likely to implement it in `trl`.
      options:
        - label: "The method implementation is available"
        - label: "The model weights are available"
        - label: "The training datasets are available"

  - type: textarea
    id: additional-info
    attributes:
      label: Provide useful links for the implementation
      description: |
        Please provide information regarding the implementation, the weights, and the authors.
        Please mention the authors by @gh-username if you're aware of their usernames.

.github/PULL_REQUEST_TEMPLATE.md
ADDED
@@ -0,0 +1,31 @@
# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet though.

Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution.

Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change.

Once you're done, someone will review your PR shortly. They may suggest changes to make the code even better.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#create-a-pull-request),
      Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

.github/codeql/custom-queries.qls
ADDED
@@ -0,0 +1,19 @@
import codeql

from WorkflowString interpolation, Workflow workflow
where
  interpolation.getStringValue().matches("${{ github.event.issue.title }}") or
  interpolation.getStringValue().matches("${{ github.event.issue.body }}") or
  interpolation.getStringValue().matches("${{ github.event.pull_request.title }}") or
  interpolation.getStringValue().matches("${{ github.event.pull_request.body }}") or
  interpolation.getStringValue().matches("${{ github.event.review.body }}") or
  interpolation.getStringValue().matches("${{ github.event.comment.body }}") or
  interpolation.getStringValue().matches("${{ github.event.inputs.* }}") or
  interpolation.getStringValue().matches("${{ github.event.head_commit.message }}")
  interpolation.getStringValue().matches("${{ github.event.* }}") and
  (
    step.getKey() = "run" or // Injection in run
    step.getKey() = "env" or // Injection via env
    step.getKey() = "with" // Injection via with
  )
select workflow, "🚨 Do not use directly as input of action"

.github/workflows/build_documentation.yml
ADDED
@@ -0,0 +1,19 @@
name: Build documentation

on:
  push:
    branches:
      - main
      - doc-builder*
      - v*-release

jobs:
  build:
    uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
    with:
      commit_sha: ${{ github.sha }}
      package: trl
      version_tag_suffix: ""
      custom_container: huggingface/transformers-doc-builder
    secrets:
      hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}

.github/workflows/build_pr_documentation.yml
ADDED
@@ -0,0 +1,19 @@
name: Build PR Documentation

on:
  pull_request:

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  build:
    if: github.event.pull_request.draft == false
    uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
    with:
      commit_sha: ${{ github.event.pull_request.head.sha }}
      pr_number: ${{ github.event.number }}
      package: trl
      version_tag_suffix: ""
      custom_container: huggingface/transformers-doc-builder

.github/workflows/clear_cache.yml
ADDED
@@ -0,0 +1,33 @@
name: "Cleanup Cache"

on:
  workflow_dispatch:
  schedule:
    - cron: "0 0 * * *"

jobs:
  cleanup:
    runs-on: ubuntu-latest
    steps:
      - name: Check out code
        uses: actions/checkout@v4

      - name: Cleanup
        run: |
          gh extension install actions/gh-actions-cache

          REPO=${{ github.repository }}

          echo "Fetching list of cache key"
          cacheKeysForPR=$(gh actions-cache list -R $REPO | cut -f 1 )

          ## Setting this to not fail the workflow while deleting cache keys.
          set +e
          echo "Deleting caches..."
          for cacheKey in $cacheKeysForPR
          do
              gh actions-cache delete $cacheKey -R $REPO --confirm
          done
          echo "Done"
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}

.github/workflows/codeQL.yml
ADDED
@@ -0,0 +1,26 @@
name: "CodeQL Analysis - Workflows"

on:
  workflow_dispatch:

jobs:
  analyze:
    name: "Analyze GitHub Workflows"
    runs-on: ubuntu-latest
    permissions:
      security-events: write
      actions: read
      contents: read

    steps:
      - name: "Checkout repository"
        uses: actions/checkout@v4

      - name: "Initialize CodeQL"
        uses: github/codeql-action/init@v2
        with:
          languages: "yaml"
          queries: +security-and-quality, ./.github/codeql/custom-queries.qls

      - name: "Perform CodeQL Analysis"
        uses: github/codeql-action/analyze@v2

.github/workflows/docker-build.yml
ADDED
@@ -0,0 +1,95 @@
name: Build Docker images (scheduled)

on:
  workflow_dispatch:
  workflow_call:
  schedule:
    - cron: "0 1 * * *"

concurrency:
  group: docker-image-builds
  cancel-in-progress: false

env:
  CI_SLACK_CHANNEL: ${{ secrets.CI_DOCKER_CHANNEL }}

jobs:
  trl-latest:
    name: "Latest TRL GPU"
    runs-on: ubuntu-latest
    steps:
      - name: Cleanup disk
        run: |
          sudo ls -l /usr/local/lib/
          sudo ls -l /usr/share/
          sudo du -sh /usr/local/lib/
          sudo du -sh /usr/share/
          sudo rm -rf /usr/local/lib/android
          sudo rm -rf /usr/share/dotnet
          sudo du -sh /usr/local/lib/
          sudo du -sh /usr/share/
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v1
      - name: Check out code
        uses: actions/checkout@v4
      - name: Login to DockerHub
        uses: docker/login-action@v1
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_PASSWORD }}

      - name: Build and Push GPU
        uses: docker/build-push-action@v4
        with:
          context: ./docker/trl-latest-gpu
          push: true
          tags: huggingface/trl-latest-gpu

      - name: Post to Slack
        if: always()
        uses: huggingface/hf-workflows/.github/actions/post-slack@main
        with:
          slack_channel: ${{ env.CI_SLACK_CHANNEL }}
          title: 🤗 Results of the trl-latest-gpu Docker Image build
          status: ${{ job.status }}
          slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

  trl-source:
    name: "Latest TRL + HF ecosystem from source"
    runs-on: ubuntu-latest
    steps:
      - name: Cleanup disk
        run: |
          sudo ls -l /usr/local/lib/
          sudo ls -l /usr/share/
          sudo du -sh /usr/local/lib/
          sudo du -sh /usr/share/
          sudo rm -rf /usr/local/lib/android
          sudo rm -rf /usr/share/dotnet
          sudo du -sh /usr/local/lib/
          sudo du -sh /usr/share/
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v1
      - name: Check out code
        uses: actions/checkout@v4
      - name: Login to DockerHub
        uses: docker/login-action@v1
        with:
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_PASSWORD }}

      - name: Build and Push GPU
        uses: docker/build-push-action@v4
        with:
          context: ./docker/trl-source-gpu
          push: true
          tags: huggingface/trl-source-gpu

      - name: Post to Slack
        if: always()
        uses: huggingface/hf-workflows/.github/actions/post-slack@main
        with:
          slack_channel: ${{ env.CI_SLACK_CHANNEL }}
          title: 🤗 Results of the trl-source-gpu Docker Image build
          status: ${{ job.status }}
          slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

.github/workflows/issue_auto_labeller.yml
ADDED
@@ -0,0 +1,15 @@
name: "Hugging Face Issue Labeler"
on:
  issues:
    types: opened

jobs:
  triage:
    runs-on: ubuntu-latest
    permissions:
      issues: write
    steps:
      - uses: actions/checkout@v3
      - uses: August-murr/auto-labeler@main
        with:
          hf-api-key: ${{ secrets.CI_HF_API_TOKEN }}

.github/workflows/pr_style_bot.yml
ADDED
@@ -0,0 +1,127 @@
name: PR Style Bot

on:
  workflow_dispatch:


permissions:
  contents: write
  pull-requests: write

jobs:
  run-style-bot:
    if: >
      contains(github.event.comment.body, '@bot /style') &&
      github.event.issue.pull_request != null
    runs-on: ubuntu-latest

    steps:
      - name: Extract PR details
        id: pr_info
        uses: actions/github-script@v6
        with:
          script: |
            const prNumber = context.payload.issue.number;
            const { data: pr } = await github.rest.pulls.get({
              owner: context.repo.owner,
              repo: context.repo.repo,
              pull_number: prNumber
            });

            // We capture both the branch ref and the "full_name" of the head repo
            // so that we can check out the correct repository & branch (including forks).
            core.setOutput("prNumber", prNumber);
            core.setOutput("headRef", pr.head.ref);
            core.setOutput("headRepoFullName", pr.head.repo.full_name);

      - name: Check out PR branch
        uses: actions/checkout@v3
        env:
          HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
          HEADREF: ${{ steps.pr_info.outputs.headRef }}
        with:
          # Instead of checking out the base repo, use the contributor's repo name
          repository: ${{ env.HEADREPOFULLNAME }}
          ref: ${{ env.HEADREF }}
          # You may need fetch-depth: 0 for being able to push
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Debug
        env:
          HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
          HEADREF: ${{ steps.pr_info.outputs.headRef }}
          PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
        run: |
          echo "PR number: ${{ env.PRNUMBER }}"
          echo "Head Ref: ${{ env.HEADREF }}"
          echo "Head Repo Full Name: ${{ env.HEADREPOFULLNAME }}"

      - name: Set up Python
        uses: actions/setup-python@v4

      - name: Install dependencies
        run: |
          pip install ruff pre-commit

      - name: Download Makefile from main branch
        run: |
          curl -o main_Makefile https://raw.githubusercontent.com/huggingface/trl/main/Makefile

      - name: Compare Makefiles
        run: |
          if ! diff -q main_Makefile Makefile; then
            echo "Error: The Makefile has changed. Please ensure it matches the main branch."
            exit 1
          fi
          echo "No changes in Makefile. Proceeding..."
          rm -rf main_Makefile

      - name: Run make style and make quality
        run: |
          make precommit || true

      - name: Commit and push changes
        id: commit_and_push
        env:
          HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
          HEADREF: ${{ steps.pr_info.outputs.headRef }}
          PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          echo "HEADREPOFULLNAME: ${{ env.HEADREPOFULLNAME }}, HEADREF: ${{ env.HEADREF }}"
          # Configure git with the Actions bot user
          git config user.name "github-actions[bot]"
          git config user.email "github-actions[bot]@users.noreply.github.com"

          # Make sure your 'origin' remote is set to the contributor's fork
          git remote set-url origin "https://x-access-token:${GITHUB_TOKEN}@github.com/${{ env.HEADREPOFULLNAME }}.git"

          # If there are changes after running style/quality, commit them
          if [ -n "$(git status --porcelain)" ]; then
            git add .
            git commit -m "Apply style fixes"
            # Push to the original contributor's forked branch
            git push origin HEAD:${{ env.HEADREF }}
            echo "changes_pushed=true" >> $GITHUB_OUTPUT
          else
            echo "No changes to commit."
            echo "changes_pushed=false" >> $GITHUB_OUTPUT
          fi

      - name: Comment on PR with workflow run link
        if: steps.commit_and_push.outputs.changes_pushed == 'true'
        uses: actions/github-script@v6
        with:
          script: |
            const prNumber = parseInt(process.env.prNumber, 10);
            const runUrl = `${process.env.GITHUB_SERVER_URL}/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}`

            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: prNumber,
              body: `Style fixes have been applied. [View the workflow run here](${runUrl}).`
            });
        env:
          prNumber: ${{ steps.pr_info.outputs.prNumber }}

.github/workflows/slow-tests.yml
ADDED
@@ -0,0 +1,98 @@
name: Slow tests (on push)

on:
  push:
    branches: [ main ]
    paths:
      # Run only when python files are modified
      - "trl/**.py"
      - "examples/**.py"
env:
  RUN_SLOW: "yes"
  IS_GITHUB_CI: "1"
  SLACK_API_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}


jobs:
  run_all_tests_single_gpu:
    strategy:
      fail-fast: false
      matrix:
        docker-image-name: ["huggingface/trl-latest-gpu:latest", "huggingface/trl-source-gpu:latest"]
    runs-on:
      group: aws-g4dn-2xlarge
    env:
      CUDA_VISIBLE_DEVICES: "0"
      TEST_TYPE: "single_gpu_${{ matrix.docker-image-name }}"
    container:
      image: ${{ matrix.docker-image-name }}
      options: --gpus all --shm-size "16gb" -e NVIDIA_DISABLE_REQUIRE=true
    defaults:
      run:
        shell: bash
    steps:
      - uses: actions/checkout@v4
      - name: Pip install
        run: |
          source activate trl
          pip install -e ".[test]" --no-deps
          pip install pytest-reportlog parameterized

      - name: Run slow SFT tests on single GPU
        if: always()
        run: |
          source activate trl
          make slow_tests

      - name: Generate Report
        if: always()
        run: |
          pip install slack_sdk tabulate
          python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY


  run_all_tests_multi_gpu:
    strategy:
      fail-fast: false
      matrix:
        docker-image-name: ["huggingface/trl-latest-gpu:latest", "huggingface/trl-source-gpu:latest"]
    runs-on:
      group: aws-g4dn-2xlarge
    env:
      CUDA_VISIBLE_DEVICES: "0,1"
      TEST_TYPE: "multi_gpu_${{ matrix.docker-image-name }}"
    container:
      image: ${{ matrix.docker-image-name }}
      options: --gpus all --shm-size "16gb" -e NVIDIA_DISABLE_REQUIRE=true
    defaults:
      run:
        shell: bash
    steps:
      - uses: actions/checkout@v4
      - name: Pip install
        run: |
          source activate trl
          pip install -e ".[test]" --no-deps
          pip install pytest-reportlog parameterized

      - name: Run slow SFT tests on Multi GPU
        if: always()
        run: |
          source activate trl
          make slow_tests

      - name: Run end-to-end examples tests on multi GPU
        if: always()
        run: |
          source activate trl
          pip install deepspeed
          make test_examples

      - name: Generate Reports
        if: always()
        run: |
          pip install slack_sdk tabulate
          python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
          python scripts/log_example_reports.py --text_file_name temp_results_sft_tests.txt >> $GITHUB_STEP_SUMMARY
          python scripts/log_example_reports.py --text_file_name temp_results_dpo_tests.txt >> $GITHUB_STEP_SUMMARY
          rm *.txt

.github/workflows/tests.yml
ADDED
@@ -0,0 +1,252 @@
name: Tests

on:
  push:
    branches: [ main ]
  pull_request:
    paths:
      # Run only when relevant files are modified
      - ".github/**.yml"
      - "examples/**.py"
      - "scripts/**.py"
      - "tests/**.py"
      - "trl/**.py"
      - "setup.py"

env:
  TQDM_DISABLE: 1
  CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}

jobs:
  check_code_quality:
    name: Check code quality
    runs-on: ubuntu-latest
    if: github.event.pull_request.draft == false
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python 3.12
        uses: actions/setup-python@v5
        with:
          python-version: 3.12
      - uses: pre-commit/[email protected]
        with:
          extra_args: --all-files

  tests:
    name: Tests
    strategy:
      matrix:
        python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
      fail-fast: false
    runs-on:
      group: aws-g4dn-2xlarge
    container:
      image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
      options: --gpus all
    defaults:
      run:
        shell: bash
    if: github.event.pull_request.draft == false
    steps:
      - name: Git checkout
        uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install Make and Git
        run: |
          apt-get update && apt-get install -y make git curl

      - name: Install uv
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh

      - name: Create Python virtual environment
        run: |
          uv venv
          uv pip install --upgrade setuptools wheel

      - name: Install dependencies
        run: |
          source .venv/bin/activate
          uv pip install ".[dev]"

      - name: Test with pytest
        run: |
          source .venv/bin/activate
          make test

      - name: Post to Slack
        if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
        uses: huggingface/hf-workflows/.github/actions/post-slack@main
        with:
          slack_channel: ${{ env.CI_SLACK_CHANNEL }}
          title: Results with Python ${{ matrix.python-version }} and latest dependencies
          status: ${{ job.status }}
          slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

  tests_dev:
    name: Tests with dev dependencies
    runs-on:
      group: aws-g4dn-2xlarge
    container:
      image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
      options: --gpus all
    defaults:
      run:
        shell: bash
    if: github.event.pull_request.draft == false
    steps:
      - name: Git checkout
        uses: actions/checkout@v4

      - name: Set up Python 3.12
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'

      - name: Install Make and Git
        run: |
          apt-get update && apt-get install -y make git curl

      - name: Install uv
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh

      - name: Create Python virtual environment
        run: |
          uv venv
          uv pip install --upgrade setuptools wheel

      - name: Install dependencies
        run: |
          source .venv/bin/activate
          uv pip install ".[dev]"
          uv pip install -U git+https://github.com/huggingface/accelerate.git
          uv pip install -U git+https://github.com/huggingface/datasets.git
          uv pip install -U git+https://github.com/huggingface/transformers.git


      - name: Test with pytest
        run: |
          source .venv/bin/activate
          make test

      - name: Post to Slack
        if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
        uses: huggingface/hf-workflows/.github/actions/post-slack@main
        with:
          slack_channel: ${{ env.CI_SLACK_CHANNEL }}
          title: Results with Python 3.12 and dev dependencies
          status: ${{ job.status }}
          slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

  tests_wo_optional_deps:
    name: Tests without optional dependencies
    runs-on:
      group: aws-g4dn-2xlarge
    container:
      image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
      options: --gpus all
    defaults:
      run:
        shell: bash
    if: github.event.pull_request.draft == false
    steps:
      - name: Git checkout
        uses: actions/checkout@v4

      - name: Set up Python 3.12
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'

      - name: Install Make and Git
        run: |
          apt-get update && apt-get install -y make git curl

      - name: Install uv
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh

      - name: Create Python virtual environment
        run: |
          uv venv
          uv pip install --upgrade setuptools wheel

      - name: Install dependencies
        run: |
          source .venv/bin/activate
          uv pip install ".[test]"

      - name: Test with pytest
        run: |
          source .venv/bin/activate
          make test

      - name: Post to Slack
        if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
        uses: huggingface/hf-workflows/.github/actions/post-slack@main
        with:
          slack_channel: ${{ env.CI_SLACK_CHANNEL }}
          title: Results with Python 3.12 without optional dependencies
          status: ${{ job.status }}
          slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

  tests_min_versions:
    name: Tests with minimum versions
    runs-on:
      group: aws-g4dn-2xlarge
    container:
      image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
      options: --gpus all
    defaults:
      run:
        shell: bash
    if: github.event.pull_request.draft == false
    steps:
      - name: Git checkout
        uses: actions/checkout@v4

      - name: Set up Python 3.12
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'

      - name: Install Make and Git
        run: |
          apt-get update && apt-get install -y make git curl

      - name: Install uv
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh

      - name: Create Python virtual environment
        run: |
          uv venv
          uv pip install --upgrade setuptools wheel

      - name: Install dependencies
        run: |
          source .venv/bin/activate
          uv pip install ".[dev]"
          uv pip install accelerate==1.4.0
          uv pip install datasets==3.0.0
          uv pip install transformers==4.51.0

      - name: Test with pytest
        run: |
          source .venv/bin/activate
          make test

      - name: Post to Slack
        if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
        uses: huggingface/hf-workflows/.github/actions/post-slack@main
        with:
          slack_channel: ${{ env.CI_SLACK_CHANNEL }}
          title: Results with Python 3.12 and minimum dependencies versions
          status: ${{ job.status }}
          slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

.github/workflows/tests_latest.yml
ADDED
@@ -0,0 +1,66 @@
name: Tests latest TRL release with dev dependencies

on:
  schedule:
    - cron: '0 0 * * *' # Runs daily at midnight UTC

  workflow_dispatch:

env:
  TQDM_DISABLE: 1
  CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}

jobs:
  tests:
    name: Tests latest TRL release with dev dependencies
    runs-on:
      group: aws-g4dn-2xlarge
    container:
      image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
      options: --gpus all
    defaults:
      run:
        shell: bash
    steps:
      - name: Git checkout
        uses: actions/checkout@v4
        with: { ref: v0.18-release }

      - name: Set up Python 3.12
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'

      - name: Install Make and Git
        run: |
          apt-get update && apt-get install -y make git curl

      - name: Install uv
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh

      - name: Create Python virtual environment
        run: |
          uv venv
          uv pip install --upgrade setuptools wheel

      - name: Install dependencies
        run: |
          source .venv/bin/activate
          uv pip install ".[dev]"
          uv pip install -U git+https://github.com/huggingface/accelerate.git
          uv pip install -U git+https://github.com/huggingface/datasets.git
          uv pip install -U git+https://github.com/huggingface/transformers.git

      - name: Test with pytest
        run: |
          source .venv/bin/activate
          make test

      - name: Post to Slack
        uses: huggingface/hf-workflows/.github/actions/post-slack@main
        with:
          slack_channel: ${{ env.CI_SLACK_CHANNEL }}
          title: Results of latest TRL with Python 3.12 and dev dependencies
          status: ${{ job.status }}
          slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

.github/workflows/trufflehog.yml
ADDED
@@ -0,0 +1,18 @@
on:
  push:

name: Secret Leaks

jobs:
  trufflehog:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - name: Secret Scanning
        uses: trufflesecurity/trufflehog@853e1e8d249fd1e29d0fcc7280d29b03df3d643d
        with:
          # exclude buggy postgres detector that is causing false positives and not relevant to our codebase
          extra_args: --results=verified,unknown --exclude-detectors=postgres

.github/workflows/upload_pr_documentation.yml
ADDED
@@ -0,0 +1,16 @@
name: Upload PR Documentation

on:
  workflow_run:
    workflows: ["Build PR Documentation"]
    types:
      - completed

jobs:
  build:
    uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
    with:
      package_name: trl
    secrets:
      hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
      comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}

.gitignore
ADDED
@@ -0,0 +1,144 @@
*.bak
.last_checked
.gitconfig
*.bak
*.log
*~
~*
_tmp*
tmp*
tags

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

.vscode
*.swp

# osx generated files
.DS_Store
.DS_Store?
.Trashes
ehthumbs.db
Thumbs.db
.idea

# pytest
.pytest_cache

# tools/trust-doc-nbs
docs_src/.last_checked

# symlinks to fastai
docs_src/fastai
tools/fastai

# link checker
checklink/cookies.txt

# .gitconfig is now autogenerated
.gitconfig

# wandb files
nbs/wandb/
examples/notebooks/wandb/
wandb/

.pre-commit-config.yaml
ADDED
@@ -0,0 +1,17 @@
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.11.10
    hooks:
      - id: ruff-check
        types_or: [ python, pyi ]
        args: [ --fix ]
      - id: ruff-format
        types_or: [ python, pyi ]

#  - repo: https://github.com/codespell-project/codespell
#    rev: v2.1.0
#    hooks:
#      - id: codespell
#        args:
#          - --ignore-words-list=nd,reacher,thist,ths,magent,ba
#          - --skip=docs/css/termynal.css,docs/js/termynal.js

CITATION.cff
ADDED
@@ -0,0 +1,34 @@
cff-version: 1.2.0
title: 'TRL: Transformer Reinforcement Learning'
message: >-
  If you use this software, please cite it using the
  metadata from this file.
type: software
authors:
  - given-names: Leandro
    family-names: von Werra
  - given-names: Younes
    family-names: Belkada
  - given-names: Lewis
    family-names: Tunstall
  - given-names: Edward
    family-names: Beeching
  - given-names: Tristan
    family-names: Thrush
  - given-names: Nathan
    family-names: Lambert
  - given-names: Shengyi
    family-names: Huang
  - given-names: Kashif
    family-names: Rasul
  - given-names: Quentin
    family-names: Gallouédec
repository-code: 'https://github.com/huggingface/trl'
abstract: "With trl you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the transformers library by \U0001F917 Hugging Face. Therefore, pre-trained language models can be directly loaded via transformers. At this point, most decoder and encoder-decoder architectures are supported."
keywords:
  - rlhf
  - deep-learning
  - pytorch
  - transformers
license: Apache-2.0
version: 0.18

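The abstract states that pre-trained language models are loaded directly via transformers before TRL's trainers are applied. A minimal sketch of that loading step follows; the checkpoint name and the trainer mentioned in the comments are illustrative assumptions, not prescribed by this commit:

```python
# Minimal sketch of the workflow the abstract describes: load a pretrained
# decoder-only model through transformers, then hand it to a TRL trainer.
# "Qwen/Qwen2-0.5B" is an arbitrary illustrative checkpoint.
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")

# A TRL trainer (for example trl.SFTTrainer or trl.PPOTrainer) would then take
# this model plus a dataset; exact constructor arguments depend on the TRL
# version, so consult the docs/ sources added in this commit.
```
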
CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,133 @@
# Contributor Covenant Code of Conduct

## Our Pledge

We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, caste, color, religion, or sexual
identity and orientation.

We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.

## Our Standards

Examples of behavior that contributes to a positive environment for our
community include:

* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
  and learning from the experience
* Focusing on what is best not just for us as individuals, but for the overall
  community

Examples of unacceptable behavior include:

* The use of sexualized language or imagery, and sexual attention or advances of
  any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email address,
  without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Enforcement Responsibilities

Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.

Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.

## Scope

This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
All complaints will be reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the
reporter of any incident.

## Enforcement Guidelines

Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:

### 1. Correction

**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.

**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.

### 2. Warning

**Community Impact**: A violation through a single incident or series of
actions.

**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or permanent
ban.

### 3. Temporary Ban

**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.

**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.

### 4. Permanent Ban

**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.

**Consequence**: A permanent ban from any sort of public interaction within the
community.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.1, available at
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].

Community Impact Guidelines were inspired by
[Mozilla's code of conduct enforcement ladder][Mozilla CoC].

For answers to common questions about this code of conduct, see the FAQ at
[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
[https://www.contributor-covenant.org/translations][translations].

[homepage]: https://www.contributor-covenant.org
[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
[Mozilla CoC]: https://github.com/mozilla/diversity
[FAQ]: https://www.contributor-covenant.org/faq
[translations]: https://www.contributor-covenant.org/translations

CONTRIBUTING.md
ADDED
@@ -0,0 +1,767 @@
1 |
+
# How to contribute to TRL?
|
2 |
+
|
3 |
+
Everyone is welcome to contribute, and we value everybody's contribution. Code
|
4 |
+
contributions are not the only way to help the community. Answering questions, helping
|
5 |
+
others, and improving the documentation are also immensely valuable.
|
6 |
+
|
7 |
+
It also helps us if you spread the word! Reference the library in blog posts
|
8 |
+
about the awesome projects it made possible, shout out on Twitter every time it has
|
9 |
+
helped you, or simply ⭐️ the repository to say thank you.
|
10 |
+
|
11 |
+
However you choose to contribute, please be mindful and respect our
|
12 |
+
[code of conduct](https://github.com/huggingface/trl/blob/main/CODE_OF_CONDUCT.md).
|
13 |
+
|
14 |
+
**This guide was heavily inspired by the awesome [scikit-learn guide to contributing](https://github.com/scikit-learn/scikit-learn/blob/main/CONTRIBUTING.md).**
|
15 |
+
|
16 |
+
## Ways to contribute
|
17 |
+
|
18 |
+
There are several ways you can contribute to TRL:
|
19 |
+
|
20 |
+
* Fix outstanding issues with the existing code.
|
21 |
+
* Submit issues related to bugs or desired new features.
|
22 |
+
* Implement trainers for new post-training algorithms.
|
23 |
+
* Contribute to the examples or the documentation.
|
24 |
+
|
25 |
+
If you don't know where to start, there is a special [Good First
|
26 |
+
Issue](https://github.com/huggingface/trl/labels/%F0%9F%91%B6%20good%20first%20issue) listing. It will give you a list of
|
27 |
+
open issues that are beginner-friendly and help you start contributing to open-source. The best way to do that is to open a Pull Request and link it to the issue that you'd like to work on. We try to give priority to opened PRs as we can easily track the progress of the fix, and if the contributor does not have time anymore, someone else can take the PR over.
|
28 |
+
|
29 |
+
For something slightly more challenging, you can also take a look at the [Good Second Issue](https://github.com/huggingface/trl/labels/Good%20Second%20Issue) list. In general though, if you feel like you know what you're doing, go for it and we'll help you get there! 🚀
|
30 |
+
|
31 |
+
> All contributions are equally valuable to the community. 🥰
|
32 |
+
|
33 |
+
Before you start contributing make sure you have installed all the dev tools:
|
34 |
+
|
35 |
+
```bash
|
36 |
+
pip install -e .[dev]
|
37 |
+
```
|
38 |
+
|
39 |
+
## Fixing outstanding issues
|
40 |
+
|
41 |
+
If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](#submitting-a-pull-request-pr) and open a Pull Request!
|
42 |
+
|
43 |
+
## Submitting a bug-related issue or feature request
|
44 |
+
|
45 |
+
Do your best to follow these guidelines when submitting a bug-related issue or a feature request. It will make it easier for us to come back to you quickly and with good feedback.
|
46 |
+
|
47 |
+
### Did you find a bug?
|
48 |
+
|
49 |
+
The TRL library is robust and reliable thanks to users who report the problems they encounter.
|
50 |
+
|
51 |
+
Before you report an issue, we would really appreciate it if you could **make sure the bug was not
|
52 |
+
already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the library itself, and not your code.
|
53 |
+
|
54 |
+
Once you've confirmed the bug hasn't already been reported, please include the following information in your issue so we can quickly resolve it:
|
55 |
+
|
56 |
+
* Your **OS type and version**, **Python**, **PyTorch**, **TRL** and **Transformers** versions.
|
57 |
+
* A short, self-contained code snippet that allows us to reproduce the bug in
|
58 |
+
less than 30s.
|
59 |
+
* The *full* traceback if an exception is raised.
|
60 |
+
* Attach any other additional information, like screenshots, you think may help.
|
61 |
+
|
62 |
+
To get the OS and software versions automatically, run the following command:
|
63 |
+
|
64 |
+
```bash
|
65 |
+
trl env
|
66 |
+
```
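For reference, a minimal reproduction snippet might look like the sketch below; the model, dataset, and trainer are placeholders chosen purely for illustration, not a required template.

```python
# Hypothetical minimal reproduction: swap in the trainer and arguments that trigger your bug.
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train[:8]")  # tiny slice keeps the repro fast
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",  # placeholder model
    args=SFTConfig(output_dir="repro", max_steps=1, report_to="none"),
    train_dataset=dataset,
)
trainer.train()  # include the full traceback this call produces
```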
|
67 |
+
|
68 |
+
### Do you want a new feature?
|
69 |
+
|
70 |
+
If there is a new feature you'd like to see in TRL, please open an issue and describe:
|
71 |
+
|
72 |
+
1. What is the *motivation* behind this feature? Is it related to a problem or frustration with the library? Is it a feature related to something you need for a project? Is it something you worked on and think it could benefit the community?
|
73 |
+
|
74 |
+
Whatever it is, we'd love to hear about it!
|
75 |
+
|
76 |
+
2. Describe your requested feature in as much detail as possible. The more you can tell us about it, the better we'll be able to help you.
|
77 |
+
3. Provide a *code snippet* that demonstrates the feature's usage.
|
78 |
+
4. If the feature is related to a paper, please include a link.
|
79 |
+
|
80 |
+
If your issue is well written we're already 80% of the way there by the time you create it.
|
81 |
+
|
82 |
+
## Do you want to implement a new trainer?
|
83 |
+
|
84 |
+
New post-training methods are published frequently and those that satisfy the following criteria are good candidates to be integrated into TRL:
|
85 |
+
|
86 |
+
* **Simplicity:** Does the new method achieve similar performance as prior methods, but with less complexity? A good example is Direct Preference Optimization (DPO) [[Rafailov et al, 2023]](https://huggingface.co/papers/2305.18290), which provided a simpler and compelling alternative to RLHF methods.
|
87 |
+
* **Efficiency:** Does the new method provide a significant improvement in training efficiency? A good example is Odds Ratio Preference Optimization (ORPO) [[Hong et al, 2023]](https://huggingface.co/papers/2403.07691), which utilizes a similar objective as DPO but requires half the GPU VRAM.
|
88 |
+
|
89 |
+
Methods that only provide incremental improvements at the expense of added complexity or compute costs are unlikely to be included in TRL.
|
90 |
+
|
91 |
+
If you want to implement a trainer for a new post-training method, first open an issue and provide the following information:
|
92 |
+
|
93 |
+
* A short description of the method and a link to the paper.
|
94 |
+
* Link to the implementation if it is open-sourced.
|
95 |
+
* Link to model weights trained with the method if they are available.
|
96 |
+
|
97 |
+
Based on the community and maintainer feedback, the next step will be to implement the trainer and config classes. See the following examples for inspiration:
|
98 |
+
|
99 |
+
* Paired preference optimisation: [`dpo_trainer.py`](./trl/trainer/dpo_trainer.py) and [`dpo_config.py`](./trl/trainer/dpo_config.py)
|
100 |
+
* RL-based optimisation: [`rloo_trainer.py`](./trl/trainer/rloo_trainer.py) and [`rloo_config.py`](./trl/trainer/rloo_config.py)
|
101 |
+
* Online optimisation: [`online_dpo_trainer.py`](./trl/trainer/online_dpo_trainer.py) and [`online_dpo_config.py`](./trl/trainer/online_dpo_config.py)
|
102 |
+
|
103 |
+
## Do you want to add documentation?
|
104 |
+
|
105 |
+
We're always looking for improvements to the documentation that make it more clear and accurate. Please let us know how the documentation can be improved, such as typos, dead links, and any missing, unclear, or inaccurate content... We'll be happy to make the changes or help you contribute if you're interested!
|
106 |
+
|
107 |
+
## Submitting a pull request (PR)
|
108 |
+
|
109 |
+
Before writing code, we strongly advise you to search through the existing PRs or
|
110 |
+
issues to make sure that nobody is already working on the same thing. If you are
|
111 |
+
unsure, it is always a good idea to open an issue to get some feedback.
|
112 |
+
|
113 |
+
You will need basic `git` proficiency to be able to contribute to
|
114 |
+
TRL. `git` is not the easiest tool to use but it has the greatest
|
115 |
+
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
|
116 |
+
Git](https://git-scm.com/book/en/v2) is a very good reference.
|
117 |
+
|
118 |
+
Follow these steps to start contributing:
|
119 |
+
|
120 |
+
1. Fork the [repository](https://github.com/huggingface/trl) by
|
121 |
+
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
|
122 |
+
under your GitHub user account.
|
123 |
+
|
124 |
+
2. Clone your fork to your local disk, and add the base repository as a remote. The following command
|
125 |
+
assumes you have your public SSH key uploaded to GitHub. See the following guide for more
|
126 |
+
[information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository).
|
127 |
+
|
128 |
+
```bash
|
129 |
+
$ git clone git@github.com:<your Github handle>/trl.git
|
130 |
+
$ cd trl
|
131 |
+
$ git remote add upstream https://github.com/huggingface/trl.git
|
132 |
+
```
|
133 |
+
|
134 |
+
3. Create a new branch to hold your development changes, and do this for every new PR you work on.
|
135 |
+
|
136 |
+
Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)):
|
137 |
+
|
138 |
+
```bash
|
139 |
+
$ git checkout main
|
140 |
+
$ git fetch upstream
|
141 |
+
$ git merge upstream/main
|
142 |
+
```
|
143 |
+
|
144 |
+
Once your `main` branch is synchronized, create a new branch from it:
|
145 |
+
|
146 |
+
```bash
|
147 |
+
$ git checkout -b a-descriptive-name-for-my-changes
|
148 |
+
```
|
149 |
+
|
150 |
+
**Do not** work on the `main` branch.
|
151 |
+
|
152 |
+
4. Set up a development environment by running the following command in a conda or a virtual environment you've created for working on this library:
|
153 |
+
|
154 |
+
```bash
|
155 |
+
$ pip install -e .[dev]
|
156 |
+
```
|
157 |
+
|
158 |
+
(If TRL was already installed in the virtual environment, remove
|
159 |
+
it with `pip uninstall trl` before reinstalling it.)
|
160 |
+
|
161 |
+
Alternatively, if you are using [Visual Studio Code](https://code.visualstudio.com/Download), the fastest way to get set up is by using
|
162 |
+
the provided Dev Container. Documentation on how to get started with dev containers is available [here](https://code.visualstudio.com/docs/remote/containers).
|
163 |
+
|
164 |
+
5. Develop the features on your branch.
|
165 |
+
|
166 |
+
As you work on the features, you should make sure that the test suite
|
167 |
+
passes. You should run the tests impacted by your changes like this:
|
169 |
+
|
170 |
+
```bash
|
171 |
+
$ pytest tests/<TEST_TO_RUN>.py
|
172 |
+
```
|
173 |
+
|
174 |
+
> The following commands leverage the `make` utility.
|
175 |
+
|
176 |
+
You can also run the full suite with the following command.
|
177 |
+
|
178 |
+
```bash
|
179 |
+
$ make test
|
180 |
+
```
|
181 |
+
|
182 |
+
TRL relies on `ruff` for maintaining consistent code formatting across its source files. Before submitting any PR, you should apply automatic style corrections and run code verification checks.
|
183 |
+
|
184 |
+
We provide a `precommit` target in the `Makefile` that simplifies this process by running all required checks and optimizations on only the files modified by your PR.
|
185 |
+
|
186 |
+
To apply these checks and corrections in one step, use:
|
187 |
+
|
188 |
+
```bash
|
189 |
+
$ make precommit
|
190 |
+
```
|
191 |
+
|
192 |
+
This command runs the following:
|
193 |
+
- Executes `pre-commit` hooks to automatically fix style issues with `ruff` and other tools.
|
194 |
+
- Runs additional scripts such as adding copyright information.
|
195 |
+
|
196 |
+
If you prefer to apply the style corrections separately or review them individually, the `pre-commit` hook will handle the formatting for the files in question.
|
197 |
+
|
198 |
+
Once you're happy with your changes, add changed files using `git add` and
|
199 |
+
make a commit with `git commit` to record your changes locally:
|
200 |
+
|
201 |
+
```bash
|
202 |
+
$ git add modified_file.py
|
203 |
+
$ git commit
|
204 |
+
```
|
205 |
+
|
206 |
+
Please write [good commit messages](https://chris.beams.io/posts/git-commit/).
|
207 |
+
|
208 |
+
It is a good idea to sync your copy of the code with the original
|
209 |
+
repository regularly. This way you can quickly account for changes:
|
210 |
+
|
211 |
+
```bash
|
212 |
+
$ git fetch upstream
|
213 |
+
$ git rebase upstream/main
|
214 |
+
```
|
215 |
+
|
216 |
+
Push the changes to your account using:
|
217 |
+
|
218 |
+
```bash
|
219 |
+
$ git push -u origin a-descriptive-name-for-my-changes
|
220 |
+
```
|
221 |
+
|
222 |
+
6. Once you are satisfied (**and the checklist below is happy too**), go to the
|
223 |
+
webpage of your fork on GitHub. Click on 'Pull request' to send your changes
|
224 |
+
to the project maintainers for review.
|
225 |
+
|
226 |
+
7. It's ok if maintainers ask you for changes. It happens to core contributors too! To ensure everyone can review your changes in the pull request, work on your local branch and push the updates to your fork. They will automatically appear in the pull request.
|
227 |
+
|
228 |
+
|
229 |
+
### Checklist
|
230 |
+
|
231 |
+
1. The title of your pull request should be a summary of its contribution;
|
232 |
+
2. If your pull request addresses an issue, please mention the issue number in
|
233 |
+
the pull request description to make sure they are linked (and people
|
234 |
+
consulting the issue know you are working on it);
|
235 |
+
3. To indicate a work in progress please prefix the title with `[WIP]`, or mark
|
236 |
+
the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate
|
237 |
+
it from PRs ready to be merged;
|
238 |
+
4. Make sure existing tests pass;
|
239 |
+
5. Add high-coverage tests. No quality testing = no merge.
|
240 |
+
|
241 |
+
|
242 |
+
### Tests
|
243 |
+
|
244 |
+
An extensive test suite is included to test the library behavior and several examples. Library tests can be found in
|
245 |
+
the [tests folder](https://github.com/huggingface/trl/tree/main/tests).
|
246 |
+
|
247 |
+
We use `pytest` to run the tests. From the root of the
|
248 |
+
repository here's how to run tests with `pytest` for the library:
|
249 |
+
|
250 |
+
```bash
|
251 |
+
$ python -m pytest -sv ./tests
|
252 |
+
```
|
253 |
+
|
254 |
+
That's how `make test` is implemented (without the `pip install` line)!
|
255 |
+
|
256 |
+
You can specify a smaller set of tests to test only the feature
|
257 |
+
you're working on.
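For example (the test file and keyword below are illustrative; pick the ones that match your change):

```bash
# Run a single test module
pytest tests/test_sft_trainer.py
# Run only the tests whose names match a keyword
pytest tests/test_sft_trainer.py -k "peft"
```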
|
258 |
+
|
259 |
+
### Default values guidelines
|
260 |
+
|
261 |
+
1. **Use defaults when appropriate**:
|
262 |
+
|
263 |
+
Provide default values unless the parameter's value varies significantly by use case. For example, datasets or models should not have defaults, but parameters like `learning_rate` should.
|
264 |
+
|
265 |
+
2. **Prioritize proven defaults**:
|
266 |
+
|
267 |
+
Default values should align with those recommended in the original paper or method. Alternatives require strong evidence of superior performance in most cases.
|
268 |
+
|
269 |
+
3. **Ensure safety and predictability**:
|
270 |
+
|
271 |
+
Defaults must be safe, expected and reliable. Avoid settings that could lead to surprising outcomes, such as excessive memory usage or poor performance in edge cases.
|
272 |
+
|
273 |
+
4. **Balance consistency and flexibility**:
|
274 |
+
|
275 |
+
Aim for consistent defaults across similar functions or methods. However, consistency should not take precedence over points 2 and 3.
|
276 |
+
|
277 |
+
5. **Opt-in for new features**:
|
278 |
+
|
279 |
+
Do not enable new features or improvements (e.g., novel loss functions) by default. Users should explicitly opt-in to use these.
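To make these guidelines concrete, here is a minimal sketch of a hypothetical config class (not an actual TRL class) that follows them:

```python
from dataclasses import dataclass


@dataclass
class MyTrainerConfig:
    # Required, no default: the model to train varies too much by use case (guideline 1).
    model_name_or_path: str
    # Proven default taken from the original paper (guideline 2).
    beta: float = 0.1
    # Safe, predictable default that avoids surprising memory usage (guideline 3).
    max_length: int = 1024
    # New, experimental behaviour stays opt-in (guideline 5).
    use_experimental_loss: bool = False
```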
|
280 |
+
|
281 |
+
### Writing documentation
|
282 |
+
|
283 |
+
High-quality documentation is crucial for maintaining a project that is easy to use, understand, and extend. When adding new features, ensure they are thoroughly documented to maintain consistency and clarity throughout the project.
|
284 |
+
|
285 |
+
To illustrate what good documentation looks like, here’s an example of a well-documented function:
|
286 |
+
|
287 |
+
````python
|
288 |
+
def replicate_str(string: str, n: int, sep: str = " ") -> str:
|
289 |
+
r"""
|
290 |
+
Replicate a string `n` times with a separator.
|
291 |
+
|
292 |
+
Args:
|
293 |
+
string (`str`):
|
294 |
+
String to replicate.
|
295 |
+
n (`int`):
|
296 |
+
Number of times to replicate the string.
|
297 |
+
sep (`str`, *optional*, defaults to `" "`):
|
298 |
+
Separator to use between each replication.
|
299 |
+
|
300 |
+
Returns:
|
301 |
+
`str`: The replicated string.
|
302 |
+
|
303 |
+
Examples:
|
304 |
+
```python
|
305 |
+
>>> replicate_str("hello", 3)
|
306 |
+
"hello hello hello"
|
307 |
+
>>> replicate_str("hello", 3, sep=", ")
|
308 |
+
"hello, hello, hello"
|
309 |
+
```
|
310 |
+
"""
|
311 |
+
return sep.join([string] * n)
|
312 |
+
````
|
313 |
+
|
314 |
+
* **Line Wrapping:** Applied a consistent line wrap at column 120 to improve readability.
|
315 |
+
* **Definite Articles:** Removed definite articles where possible to streamline language. (Eg: Changed "The string to replicate" to "String to replicate")
|
316 |
+
* **Type Annotations:**
|
317 |
+
* Always include type definitions, indicating if a parameter is optional and specifying the default value.
|
318 |
+
* Note that `Optional` means that the value can be `None`, and `*optional*` means that it is not required for the user to pass a value.
|
319 |
+
E.g., for arguments that can't be `None` and aren't required:
|
320 |
+
|
321 |
+
```python
|
322 |
+
foo (`int`, *optional*, defaults to `4`):
|
323 |
+
```
|
324 |
+
|
325 |
+
For arguments that can be `None` and are required:
|
326 |
+
|
327 |
+
```python
|
328 |
+
foo (`Optional[int]`):
|
329 |
+
```
|
330 |
+
|
331 |
+
For arguments that can be `None` and aren't required:
|
332 |
+
|
333 |
+
```python
|
334 |
+
foo (`Optional[int]`, *optional*, defaults to `None`):
|
335 |
+
```
|
336 |
+
|
337 |
+
* **String Defaults:**
|
338 |
+
* Ensured that default string values are wrapped in double quotes:
|
339 |
+
|
340 |
+
```python
|
341 |
+
defaults to `"foo"`
|
342 |
+
```
|
343 |
+
|
344 |
+
* **Dictionary Typing:**
|
345 |
+
* Replaced generic `dict` type hints with more explicit `dict[str, Any]` to clarify expected key-value pairs.
|
346 |
+
* **Default Value Formatting:**
|
347 |
+
* Consistently surrounded default values with backticks for improved formatting:
|
348 |
+
|
349 |
+
```python
|
350 |
+
defaults to `4`
|
351 |
+
```
|
352 |
+
|
353 |
+
* **Sub-sectioning:** When the number of arguments is large, consider breaking them into sub-sections for better readability.
|
354 |
+
|
355 |
+
```python
|
356 |
+
def calculate_statistics(data: list[float], precision: int = 2, include_variance: bool = False) -> dict[str, float]:
|
357 |
+
r"""
|
358 |
+
Calculates basic statistics for a given dataset.
|
359 |
+
|
360 |
+
Args:
|
361 |
+
> Data inputs
|
362 |
+
|
363 |
+
data (`list[float]`):
|
364 |
+
A list of numerical values to analyze.
|
365 |
+
|
366 |
+
> Configuration parameters
|
367 |
+
|
368 |
+
precision (`int`, *optional*, defaults to `2`):
|
369 |
+
Number of decimal places to round the results.
|
370 |
+
include_variance (`bool`, *optional*, defaults to `False`):
|
371 |
+
Whether to include the variance of the dataset in the results.
|
372 |
+
|
373 |
+
Returns:
|
374 |
+
`dict[str, float]`:
|
375 |
+
A dictionary containing calculated statistics such as mean, median, and optionally variance.
|
376 |
+
"""
|
377 |
+
...
|
378 |
+
```
|
379 |
+
|
380 |
+
### Deprecation and backward compatibility
|
381 |
+
|
382 |
+
Our approach to deprecation and backward compatibility is flexible and based on the feature’s usage and impact. Each deprecation is carefully evaluated, aiming to balance innovation with user needs.
|
383 |
+
|
384 |
+
When a feature or component is marked for deprecation, its use will emit a warning message. This warning will include:
|
385 |
+
|
386 |
+
- **Transition Guidance**: Instructions on how to migrate to the alternative solution or replacement.
|
387 |
+
- **Removal Version**: The target version when the feature will be removed, providing users with a clear timeframe to transition.
|
388 |
+
|
389 |
+
Example:
|
390 |
+
|
391 |
+
```python
|
392 |
+
warnings.warn(
|
393 |
+
"The `Trainer.foo` method is deprecated and will be removed in version 0.14.0. "
|
394 |
+
"Please use the `Trainer.bar` class instead.",
|
395 |
+
FutureWarning,
|
396 |
+
)
|
397 |
+
```
|
398 |
+
|
399 |
+
The deprecation and removal schedule is based on each feature's usage and impact, with examples at two extremes:
|
400 |
+
|
401 |
+
- **Experimental or Low-Use Features**: For a feature that is experimental or has limited usage, backward compatibility may not be maintained between releases. Users should therefore anticipate potential breaking changes from one version to the next.
|
402 |
+
|
403 |
+
- **Widely-Used Components**: For a feature with high usage, we aim for a more gradual transition period of approximately **5 months**, generally scheduling deprecation around **5 minor releases** after the initial warning.
|
404 |
+
|
405 |
+
These examples represent the two ends of a continuum. The specific timeline for each feature will be determined individually, balancing innovation with user stability needs.
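As a purely illustrative pattern, deprecating a renamed keyword argument typically combines the warning above with a fallback to the new code path (the function, argument names, and version number here are hypothetical):

```python
import warnings


def train(new_arg=None, old_arg=None):
    # Hypothetical example: `old_arg` has been renamed to `new_arg`.
    if old_arg is not None:
        warnings.warn(
            "`old_arg` is deprecated and will be removed in version X.Y.0. Use `new_arg` instead.",
            FutureWarning,
        )
        if new_arg is None:
            new_arg = old_arg  # keep old call sites working during the transition
    ...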
|
406 |
+
|
407 |
+
### Working with warnings
|
408 |
+
|
409 |
+
Warnings play a critical role in guiding users toward resolving potential issues, but they should be used thoughtfully to avoid unnecessary noise. Unlike logging, which provides informational context or operational details, warnings signal conditions that require attention and action. Overusing warnings can dilute their importance, leading users to ignore them entirely.
|
410 |
+
|
411 |
+
#### Definitions
|
412 |
+
|
413 |
+
- **Correct**: An operation is correct if it is valid, follows the intended approach, and aligns with the current best practices or guidelines within the codebase. This is the recommended or intended way to perform the operation.
|
414 |
+
- **Supported**: An operation is supported if it is technically valid and works within the current codebase, but it may not be the most efficient, optimal, or recommended way to perform the task. This includes deprecated features or legacy approaches that still work but may be phased out in the future.
|
415 |
+
|
416 |
+
#### Choosing the right message
|
417 |
+
|
418 |
+
- **Correct → No warning**:
|
419 |
+
If the operation is fully valid and expected, no message should be issued. The system is working as intended, so no warning is necessary.
|
420 |
+
|
421 |
+
- **Correct but deserves attention → No warning, possibly a log message**:
|
422 |
+
When an operation is correct but uncommon or requires special attention, providing an informational message can be helpful. This keeps users informed without implying any issue. If available, use the logger to output this message. Example:
|
423 |
+
|
424 |
+
```python
|
425 |
+
logger.info("This is an informational message about a rare but correct operation.")
|
426 |
+
```
|
427 |
+
|
428 |
+
- **Correct but very likely a mistake → Warning with option to disable**:
|
429 |
+
In rare cases, you may want to issue a warning for a correct operation that’s very likely a mistake. In such cases, you must provide an option to suppress the warning. This can be done with a flag in the function. Example:
|
430 |
+
|
431 |
+
```python
|
432 |
+
def my_function(foo, bar, _warn=True):
|
433 |
+
if foo == bar:
|
434 |
+
if _warn:
|
435 |
+
warnings.warn("foo and bar are the same, this is likely a mistake. Ignore this warning by setting `_warn=False`.")
|
436 |
+
# Do something
|
437 |
+
```
|
438 |
+
|
439 |
+
- **Supported but not correct → Warning**:
|
440 |
+
If the operation is technically supported but is deprecated, suboptimal, or could cause future issues (e.g., conflicting arguments), a warning should be raised. This message should be actionable, meaning it must explain how to resolve the issue. Example:
|
441 |
+
|
442 |
+
```python
|
443 |
+
def my_function(foo, bar):
|
444 |
+
if foo and bar:
|
445 |
+
warnings.warn("Both `foo` and `bar` were provided, but only one is allowed. Ignoring `foo`. Please pass only one of these arguments.")
|
446 |
+
# Do something
|
447 |
+
```
|
448 |
+
|
449 |
+
- **Not supported → Exception**:
|
450 |
+
If the operation is invalid or unsupported, raise an exception. This indicates that the operation cannot be performed and requires immediate attention. Example:
|
451 |
+
|
452 |
+
```python
|
453 |
+
def my_function(foo, bar):
|
454 |
+
if foo and bar:
|
455 |
+
raise ValueError("Both `foo` and `bar` were provided, but only one is allowed. Please pass only one of these arguments.")
|
456 |
+
```
|
457 |
+
|
458 |
+
By following this classification, you ensure that warnings, information, and exceptions are used appropriately, providing clear guidance to the user without cluttering the system with unnecessary messages.
|
459 |
+
|
460 |
+
|
461 |
+
## Making a release
|
462 |
+
|
463 |
+
> [!NOTE]
|
464 |
+
> VERSION needs to be formatted following the `v{major}.{minor}.{patch}` convention. We need to follow this convention to be able to retrieve versioned scripts.
|
465 |
+
|
466 |
+
#### 0. Prerequisites
|
467 |
+
|
468 |
+
- Dependencies:
|
469 |
+
- twine: `pip install build twine`
|
470 |
+
- Create an account (and join the `trl` project) on:
|
471 |
+
- PyPI: https://pypi.org/
|
472 |
+
- Test PyPI: https://test.pypi.org/
|
473 |
+
|
474 |
+
### Major/Minor Release
|
475 |
+
|
476 |
+
#### 1. Ensure your local repository is up to date with the upstream repository
|
477 |
+
|
478 |
+
```bash
|
479 |
+
git checkout main
|
480 |
+
git pull origin main
|
481 |
+
```
|
482 |
+
|
483 |
+
> [!WARNING]
|
484 |
+
> Do not merge other pull requests into `main` until the release is done. This is to ensure that the release is stable and does not include any untested changes. Announce internally (#trl-internal) to other maintainers that you are doing a release and that they must not merge PRs until the release is done.
|
485 |
+
|
486 |
+
#### 2. Create a release branch from main
|
487 |
+
|
488 |
+
```bash
|
489 |
+
git checkout -b release-v{major}.{minor}
|
490 |
+
```
|
491 |
+
|
492 |
+
#### 3. Change the version in the following files
|
493 |
+
|
494 |
+
- `.github/workflows/tests_latest.yml`:
|
495 |
+
```diff
|
496 |
+
- with: { ref: v{major}.{minor-1}-release }
|
497 |
+
+ with: { ref: v{major}.{minor}-release }
|
498 |
+
```
|
499 |
+
- `CITATION.cff`
|
500 |
+
```diff
|
501 |
+
- version: {major}.{minor-1}
|
502 |
+
+ version: {major}.{minor}
|
503 |
+
```
|
504 |
+
- `trl/__init__.py`
|
505 |
+
```diff
|
506 |
+
- __version__ = "{major}.{minor}.0.dev0"
|
507 |
+
+ __version__ = "{major}.{minor}.0"
|
508 |
+
```
|
509 |
+
- `setup.cfg`
|
510 |
+
```diff
|
511 |
+
- version = {major}.{minor}.0.dev0
|
512 |
+
+ version = {major}.{minor}.0
|
513 |
+
```
|
514 |
+
|
515 |
+
#### 4. Commit and push these changes
|
516 |
+
|
517 |
+
```shell
|
518 |
+
git add .github/workflows/tests_latest.yml CITATION.cff trl/__init__.py setup.cfg
|
519 |
+
git commit -m 'Release: {major}.{minor}'
|
520 |
+
git push origin release-v{major}.{minor}
|
521 |
+
```
|
522 |
+
|
523 |
+
#### 5. Create a pull request
|
524 |
+
|
525 |
+
from `release-v{major}.{minor}` to `main`, named `Release: v{major}.{minor}`, wait for tests to pass, and request a review.
|
526 |
+
|
527 |
+
#### 6. Once the pull request is approved, merge it into `main`
|
528 |
+
|
529 |
+
#### 7. Add a tag in git to mark the release
|
530 |
+
|
531 |
+
```shell
|
532 |
+
git checkout main
|
533 |
+
git pull origin main
|
534 |
+
git tag -a v{major}.{minor}.0 -m 'Adds tag v{major}.{minor}.0 for PyPI'
|
535 |
+
git push origin v{major}.{minor}.0
|
536 |
+
```
|
537 |
+
|
538 |
+
#### 8. Create a branch `v{major}.{minor}-release` for future patch releases.
|
539 |
+
|
540 |
+
```shell
|
541 |
+
git checkout -b v{major}.{minor}-release
|
542 |
+
git push origin v{major}.{minor}-release
|
543 |
+
```
|
544 |
+
|
545 |
+
This ensures that future patch releases (`v{major}.{minor}.1`, `v{major}.{minor}.2`, etc.) can be made separately from `main`.
|
546 |
+
|
547 |
+
#### 9. Create the wheels for your release
|
548 |
+
|
549 |
+
These are the artifacts that will be uploaded to PyPI and installed by users via `pip install trl`.
|
550 |
+
|
551 |
+
Clean previous builds:
|
552 |
+
|
553 |
+
```shell
|
554 |
+
rm -rf build dist
|
555 |
+
```
|
556 |
+
|
557 |
+
At the root of your repo, run
|
558 |
+
|
559 |
+
```bash
|
560 |
+
python -m build .
|
561 |
+
```
|
562 |
+
|
563 |
+
This will create a folder named `dist` containing the new version of your package.
|
564 |
+
|
565 |
+
#### 10. Upload the package to PyPI Test
|
566 |
+
|
567 |
+
> [!IMPORTANT]
|
568 |
+
> Do not skip this step. It is important to test the package before uploading it to the main PyPI server.
|
569 |
+
|
570 |
+
```shell
|
571 |
+
twine upload dist/* -r testpypi
|
572 |
+
```
|
573 |
+
|
574 |
+
Then in a fresh environment containing all dependencies you need, try to install your new package from the PyPI test server.
|
575 |
+
|
576 |
+
```bash
|
577 |
+
pip install -i https://test.pypi.org/simple/ trl
|
578 |
+
```
|
579 |
+
|
580 |
+
You might get errors for missing dependencies since the PyPI test server does not contain all packages like PyPI does. To make sure you have everything you can do:
|
581 |
+
|
582 |
+
```bash
|
583 |
+
pip install trl
|
584 |
+
pip uninstall trl
|
585 |
+
```
|
586 |
+
|
587 |
+
(the second line will remove trl but keep all its dependencies).
|
588 |
+
|
589 |
+
Also make sure you can actually use the package! Run the following line:
|
590 |
+
|
591 |
+
```bash
|
592 |
+
python -c "from trl import *"
|
593 |
+
```
|
594 |
+
|
595 |
+
along with anything that tests:
|
596 |
+
|
597 |
+
- the core feature of your package
|
598 |
+
- the new features you’re adding in the release
|
599 |
+
|
600 |
+
#### 11. Publish on PyPI
|
601 |
+
|
602 |
+
> [!WARNING]
|
603 |
+
> This can't be reverted. Make sure you have tested everything before doing this step.
|
604 |
+
|
605 |
+
```shell
|
606 |
+
twine upload dist/*
|
607 |
+
```
|
608 |
+
|
609 |
+
#### 12. Create a GitHub Release
|
610 |
+
|
611 |
+
1. Go to the repo’s [releases section](https://github.com/huggingface/trl/releases) on GitHub.
|
612 |
+
2. Click **Draft a new release**.
|
613 |
+
3. Select the `v{major}.{minor}.0` tag you just created in step 7.
|
614 |
+
4. Add a title (`v{major}.{minor}.0`) and a short description of what’s new.
|
615 |
+
5. Click **Publish Release**.
|
616 |
+
|
617 |
+
#### 13. Bump to dev version
|
618 |
+
|
619 |
+
1. Create a branch `bump-dev-version-{major}.{minor+1}` from `main` and checkout to it.
|
620 |
+
|
621 |
+
```shell
|
622 |
+
git checkout -b bump-dev-version-{major}.{minor+1}
|
623 |
+
```
|
624 |
+
|
625 |
+
2. Change the version in the following files:
|
626 |
+
1. `trl/__init__.py`
|
627 |
+
```diff
|
628 |
+
- __version__ = "{major}.{minor}.0"
|
629 |
+
+ __version__ = "{major}.{minor+1}.0.dev0"
|
630 |
+
```
|
631 |
+
2. `setup.cfg`
|
632 |
+
```diff
|
633 |
+
- version = {major}.{minor}.0
|
634 |
+
+ version = {major}.{minor+1}.0.dev0
|
635 |
+
```
|
636 |
+
|
637 |
+
3. Commit and push these changes
|
638 |
+
|
639 |
+
```shell
|
640 |
+
git add trl/__init__.py setup.cfg
|
641 |
+
git commit -m '⬆️ Bump dev version'
|
642 |
+
git push origin bump-dev-version-{major}.{minor+1}
|
643 |
+
```
|
644 |
+
|
645 |
+
4. Create a pull request from `bump-dev-version-{major}.{minor+1}` to `main`, named `⬆️ Bump dev version`, and request urgent review.
|
646 |
+
|
647 |
+
5. Once the pull request is approved, merge it into `main`.
|
648 |
+
|
649 |
+
6. The codebase is now ready for the next development cycle, inform the team in the #trl-internal channel.
|
650 |
+
|
651 |
+
|
652 |
+
## Making a patch release
|
653 |
+
|
654 |
+
#### 1. Ensure your local repository is up to date with the upstream repository
|
655 |
+
|
656 |
+
```bash
|
657 |
+
git checkout v{major}.{minor}-release
|
658 |
+
git pull origin v{major}.{minor}-release
|
659 |
+
```
|
660 |
+
|
661 |
+
#### 2. Cherry-pick the changes you want to include in the patch release
|
662 |
+
|
663 |
+
```bash
|
664 |
+
git cherry-pick <commit-hash-0>
|
665 |
+
git cherry-pick <commit-hash-1>
|
666 |
+
...
|
667 |
+
```
|
668 |
+
|
669 |
+
#### 3. Change the version in the following files
|
670 |
+
|
671 |
+
- `trl/__init__.py`
|
672 |
+
```diff
|
673 |
+
- __version__ = "{major}.{minor}.{patch-1}"
|
674 |
+
+ __version__ = "{major}.{minor}.{patch}"
|
675 |
+
```
|
676 |
+
- `setup.cfg`
|
677 |
+
```diff
|
678 |
+
- version = {major}.{minor}.{patch-1}
|
679 |
+
+ version = {major}.{minor}.{patch}
|
680 |
+
```
|
681 |
+
|
682 |
+
#### 4. Commit and push these changes
|
683 |
+
|
684 |
+
```shell
|
685 |
+
git add trl/__init__.py setup.cfg
|
686 |
+
git commit -m 'Release: {major}.{minor}.{patch}'
|
687 |
+
git push origin v{major}.{minor}-release
|
688 |
+
```
|
689 |
+
|
690 |
+
#### 5. Wait for the CI to pass
|
691 |
+
|
692 |
+
#### 6. Add a tag in git to mark the release
|
693 |
+
|
694 |
+
```shell
|
695 |
+
git tag -a v{major}.{minor}.{patch} -m 'Adds tag v{major}.{minor}.{patch} for PyPI'
|
696 |
+
git push origin v{major}.{minor}.{patch}
|
697 |
+
```
|
698 |
+
|
699 |
+
#### 7. Create the wheels for your release
|
700 |
+
|
701 |
+
These are the artifacts that will be uploaded to PyPI and installed by users via `pip install trl`.
|
702 |
+
|
703 |
+
Clean previous builds:
|
704 |
+
|
705 |
+
```shell
|
706 |
+
rm -rf build dist
|
707 |
+
```
|
708 |
+
|
709 |
+
At the root of your repo, run
|
710 |
+
|
711 |
+
```bash
|
712 |
+
python -m build .
|
713 |
+
```
|
714 |
+
|
715 |
+
This will create a folder named `dist` containing the new version of your package.
|
716 |
+
|
717 |
+
#### 8. Upload the package to PyPI Test
|
718 |
+
|
719 |
+
> [!IMPORTANT]
|
720 |
+
> Do not skip this step. It is important to test the package before uploading it to the main PyPI server.
|
721 |
+
|
722 |
+
```shell
|
723 |
+
twine upload dist/* -r testpypi
|
724 |
+
```
|
725 |
+
|
726 |
+
Then in a fresh environment containing all dependencies you need, try to install your new package from the PyPI test server.
|
727 |
+
|
728 |
+
```bash
|
729 |
+
pip install -i https://test.pypi.org/simple/ trl
|
730 |
+
```
|
731 |
+
|
732 |
+
You might get errors for missing dependencies since the PyPI test server does not contain all packages like PyPI does. To make sure you have everything you can do:
|
733 |
+
|
734 |
+
```bash
|
735 |
+
pip install trl
|
736 |
+
pip uninstall trl
|
737 |
+
```
|
738 |
+
|
739 |
+
(the second line will remove trl but keep all its dependencies).
|
740 |
+
|
741 |
+
Also make sure you can actually use the package! Run the following line:
|
742 |
+
|
743 |
+
```bash
|
744 |
+
python -c "from trl import *"
|
745 |
+
```
|
746 |
+
|
747 |
+
along with anything that tests:
|
748 |
+
|
749 |
+
- the core feature of your package
|
750 |
+
- the new features you’re adding in the release
|
751 |
+
|
752 |
+
#### 9. Publish on PyPI
|
753 |
+
|
754 |
+
> [!WARNING]
|
755 |
+
> This can't be reverted. Make sure you have tested everything before doing this step.
|
756 |
+
|
757 |
+
```shell
|
758 |
+
twine upload dist/*
|
759 |
+
```
|
760 |
+
|
761 |
+
#### 10. Create a GitHub Release
|
762 |
+
|
763 |
+
1. Go to the repo’s [releases section](https://github.com/huggingface/trl/releases) on GitHub.
|
764 |
+
2. Click **Draft a new release**.
|
765 |
+
3. Select the `v{major}.{minor}.{patch}` tag you just created in step 6.
|
766 |
+
4. Add a title (`v{major}.{minor}.{patch}`) and a short description of what’s new.
|
767 |
+
5. Click **Publish Release**.
|
Dockerfile
ADDED
@@ -0,0 +1,37 @@
1 |
+
# https://huggingface.co/docs/hub/spaces-dev-mode#docker-spaces
|
2 |
+
|
3 |
+
FROM python:3.13-bookworm
|
4 |
+
|
5 |
+
RUN apt-get update
|
6 |
+
RUN apt-get install -y \
|
7 |
+
bash \
|
8 |
+
curl \
|
9 |
+
git \
|
10 |
+
git-lfs \
|
11 |
+
htop \
|
12 |
+
procps \
|
13 |
+
nano \
|
14 |
+
vim \
|
15 |
+
wget
|
16 |
+
RUN rm -fr /var/lib/apt/lists/*
|
17 |
+
|
18 |
+
RUN useradd -m -u 1000 user
|
19 |
+
|
20 |
+
WORKDIR /app
|
21 |
+
RUN chown user /app
|
22 |
+
RUN chmod 755 /app
|
23 |
+
|
24 |
+
USER user
|
25 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
26 |
+
RUN curl -fsSL https://pyenv.run | bash
|
27 |
+
RUN curl -LsSf https://astral.sh/uv/install.sh | sh
|
28 |
+
|
29 |
+
COPY --chown=user . /app
|
30 |
+
|
31 |
+
RUN ls -la /app
|
32 |
+
|
33 |
+
RUN uv sync
|
34 |
+
|
35 |
+
# `7860` is the default port for Hugging Face Spaces running on Docker
|
36 |
+
# https://huggingface.co/docs/hub/en/spaces-config-reference
|
37 |
+
CMD ["python", "-m", "http.server", "--directory", "public", "7860"]
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright 2020-2025 The HuggingFace Team
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
MANIFEST.in
ADDED
@@ -0,0 +1,6 @@
1 |
+
include LICENSE
|
2 |
+
include CONTRIBUTING.md
|
3 |
+
include README.md
|
4 |
+
recursive-exclude * __pycache__
|
5 |
+
include trl/templates/*.md
|
6 |
+
include trl/accelerate_configs/*.yaml
|
Makefile
ADDED
@@ -0,0 +1,29 @@
1 |
+
.PHONY: test precommit common_tests slow_tests test_examples tests_gpu
|
2 |
+
|
3 |
+
check_dirs := examples tests trl
|
4 |
+
|
5 |
+
ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs
|
6 |
+
COMMAND_FILES_PATH = `pwd`/commands
|
7 |
+
|
8 |
+
test:
|
9 |
+
pytest -n auto -m "not slow and not low-priority" -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504|not less than or equal to 0.01)' tests/
|
10 |
+
|
11 |
+
precommit:
|
12 |
+
python scripts/add_copyrights.py
|
13 |
+
pre-commit run --all-files
|
14 |
+
|
15 |
+
slow_tests:
|
16 |
+
pytest -m "slow" tests/ $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",)
|
17 |
+
|
18 |
+
test_examples:
|
19 |
+
touch temp_results_sft_tests.txt
|
20 |
+
for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
|
21 |
+
TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_sft.sh; \
|
22 |
+
echo $$?','$${file} >> temp_results_sft_tests.txt; \
|
23 |
+
done
|
24 |
+
|
25 |
+
touch temp_results_dpo_tests.txt
|
26 |
+
for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
|
27 |
+
TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_dpo.sh; \
|
28 |
+
echo $$?','$${file} >> temp_results_dpo_tests.txt; \
|
29 |
+
done
|
README.md
ADDED
@@ -0,0 +1,210 @@
1 |
+
---
|
2 |
+
title: Trl
|
3 |
+
emoji: 🚀
|
4 |
+
colorFrom: yellow
|
5 |
+
colorTo: green
|
6 |
+
sdk: docker
|
7 |
+
pinned: false
|
8 |
+
---
|
9 |
+
|
10 |
+
# TRL - Transformer Reinforcement Learning
|
11 |
+
|
12 |
+
<div style="text-align: center">
|
13 |
+
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png" alt="TRL Banner">
|
14 |
+
</div>
|
15 |
+
|
16 |
+
<hr> <br>
|
17 |
+
|
18 |
+
<h3 align="center">
|
19 |
+
<p>A comprehensive library to post-train foundation models</p>
|
20 |
+
</h3>
|
21 |
+
|
22 |
+
<p align="center">
|
23 |
+
<a href="https://github.com/huggingface/trl/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/huggingface/trl.svg?color=blue"></a>
|
24 |
+
<a href="https://huggingface.co/docs/trl/index"><img alt="Documentation" src="https://img.shields.io/website?label=documentation&url=https%3A%2F%2Fhuggingface.co%2Fdocs%2Ftrl%2Findex&down_color=red&down_message=offline&up_color=blue&up_message=online"></a>
|
25 |
+
<a href="https://github.com/huggingface/trl/releases"><img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/trl.svg"></a>
|
26 |
+
<a href="https://huggingface.co/trl-lib"><img alt="Hugging Face Hub" src="https://img.shields.io/badge/🤗%20Hub-trl--lib-yellow"></a>
|
27 |
+
</p>
|
28 |
+
|
29 |
+
## Overview
|
30 |
+
|
31 |
+
TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO). Built on top of the [🤗 Transformers](https://github.com/huggingface/transformers) ecosystem, TRL supports a variety of model architectures and modalities, and can be scaled up across various hardware setups.
|
32 |
+
|
33 |
+
## Highlights
|
34 |
+
|
35 |
+
- **Trainers**: Various fine-tuning methods are easily accessible via trainers like [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`GRPOTrainer`](https://huggingface.co/docs/trl/grpo_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer) and more.
|
36 |
+
|
37 |
+
- **Efficient and scalable**:
|
38 |
+
- Leverages [🤗 Accelerate](https://github.com/huggingface/accelerate) to scale from single GPU to multi-node clusters using methods like [DDP](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) and [DeepSpeed](https://github.com/deepspeedai/DeepSpeed).
|
39 |
+
- Full integration with [🤗 PEFT](https://github.com/huggingface/peft) enables training on large models with modest hardware via quantization and LoRA/QLoRA.
|
40 |
+
- Integrates [🦥 Unsloth](https://github.com/unslothai/unsloth) for accelerating training using optimized kernels.
|
41 |
+
|
42 |
+
- **Command Line Interface (CLI)**: A simple interface lets you fine-tune models without needing to write code.
|
43 |
+
|
44 |
+
## Installation
|
45 |
+
|
46 |
+
### Python Package
|
47 |
+
|
48 |
+
Install the library using `pip`:
|
49 |
+
|
50 |
+
```bash
|
51 |
+
pip install trl
|
52 |
+
```
|
53 |
+
|
54 |
+
### From source
|
55 |
+
|
56 |
+
If you want to use the latest features before an official release, you can install TRL from source:
|
57 |
+
|
58 |
+
```bash
|
59 |
+
pip install git+https://github.com/huggingface/trl.git
|
60 |
+
```
|
61 |
+
|
62 |
+
### Repository
|
63 |
+
|
64 |
+
If you want to use the examples, you can clone the repository with the following command:
|
65 |
+
|
66 |
+
```bash
|
67 |
+
git clone https://github.com/huggingface/trl.git
|
68 |
+
```
|
69 |
+
|
70 |
+
## Quick Start
|
71 |
+
|
72 |
+
|
73 |
+
For more flexibility and control over training, TRL provides dedicated trainer classes to post-train language models or PEFT adapters on a custom dataset. Each trainer in TRL is a light wrapper around the 🤗 Transformers trainer and natively supports distributed training methods like DDP, DeepSpeed ZeRO, and FSDP.
|
74 |
+
|
75 |
+
### `SFTTrainer`
|
76 |
+
|
77 |
+
Here is a basic example of how to use the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer):
|
78 |
+
|
79 |
+
```python
|
80 |
+
from trl import SFTTrainer
|
81 |
+
from datasets import load_dataset
|
82 |
+
|
83 |
+
dataset = load_dataset("trl-lib/Capybara", split="train")
|
84 |
+
|
85 |
+
trainer = SFTTrainer(
|
86 |
+
model="Qwen/Qwen2.5-0.5B",
|
87 |
+
train_dataset=dataset,
|
88 |
+
)
|
89 |
+
trainer.train()
|
90 |
+
```
|
91 |
+
|
92 |
+
### `GRPOTrainer`
|
93 |
+
|
94 |
+
[`GRPOTrainer`](https://huggingface.co/docs/trl/grpo_trainer) implements the [Group Relative Policy Optimization (GRPO) algorithm](https://huggingface.co/papers/2402.03300) that is more memory-efficient than PPO and was used to train [Deepseek AI's R1](https://huggingface.co/deepseek-ai/DeepSeek-R1).
|
95 |
+
|
96 |
+
```python
|
97 |
+
from datasets import load_dataset
|
98 |
+
from trl import GRPOTrainer
|
99 |
+
|
100 |
+
dataset = load_dataset("trl-lib/tldr", split="train")
|
101 |
+
|
102 |
+
# Dummy reward function: count the number of unique characters in the completions
|
103 |
+
def reward_num_unique_chars(completions, **kwargs):
|
104 |
+
return [len(set(c)) for c in completions]
|
105 |
+
|
106 |
+
trainer = GRPOTrainer(
|
107 |
+
model="Qwen/Qwen2-0.5B-Instruct",
|
108 |
+
reward_funcs=reward_num_unique_chars,
|
109 |
+
train_dataset=dataset,
|
110 |
+
)
|
111 |
+
trainer.train()
|
112 |
+
```
|
113 |
+
|
114 |
+
### `DPOTrainer`
|
115 |
+
|
116 |
+
[`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer) implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train [Llama 3](https://huggingface.co/papers/2407.21783) and many other models. Here is a basic example of how to use the `DPOTrainer`:
|
117 |
+
|
118 |
+
```python
|
119 |
+
from datasets import load_dataset
|
120 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
121 |
+
from trl import DPOConfig, DPOTrainer
|
122 |
+
|
123 |
+
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
124 |
+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
125 |
+
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
126 |
+
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
|
127 |
+
trainer = DPOTrainer(
|
128 |
+
model=model,
|
129 |
+
args=training_args,
|
130 |
+
train_dataset=dataset,
|
131 |
+
processing_class=tokenizer
|
132 |
+
)
|
133 |
+
trainer.train()
|
134 |
+
```
|
135 |
+
|
136 |
+
### `RewardTrainer`
|
137 |
+
|
138 |
+
Here is a basic example of how to use the [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer):
|
139 |
+
|
140 |
+
```python
|
141 |
+
from trl import RewardConfig, RewardTrainer
|
142 |
+
from datasets import load_dataset
|
143 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
144 |
+
|
145 |
+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
146 |
+
model = AutoModelForSequenceClassification.from_pretrained(
|
147 |
+
"Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
|
148 |
+
)
|
149 |
+
model.config.pad_token_id = tokenizer.pad_token_id
|
150 |
+
|
151 |
+
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
152 |
+
|
153 |
+
training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
|
154 |
+
trainer = RewardTrainer(
|
155 |
+
args=training_args,
|
156 |
+
model=model,
|
157 |
+
processing_class=tokenizer,
|
158 |
+
train_dataset=dataset,
|
159 |
+
)
|
160 |
+
trainer.train()
|
161 |
+
```
|
162 |
+
|
163 |
+
## Command Line Interface (CLI)
|
164 |
+
|
165 |
+
You can use the TRL Command Line Interface (CLI) to quickly get started with post-training methods like Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO):
|
166 |
+
|
167 |
+
**SFT:**
|
168 |
+
|
169 |
+
```bash
|
170 |
+
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
|
171 |
+
--dataset_name trl-lib/Capybara \
|
172 |
+
--output_dir Qwen2.5-0.5B-SFT
|
173 |
+
```
|
174 |
+
|
175 |
+
**DPO:**
|
176 |
+
|
177 |
+
```bash
|
178 |
+
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
|
179 |
+
--dataset_name argilla/Capybara-Preferences \
|
180 |
+
--output_dir Qwen2.5-0.5B-DPO
|
181 |
+
```
|
182 |
+
|
183 |
+
Read more about CLI in the [relevant documentation section](https://huggingface.co/docs/trl/main/en/clis) or use `--help` for more details.
|
184 |
+
|
185 |
+
## Development
|
186 |
+
|
187 |
+
If you want to contribute to `trl` or customize it to your needs, make sure to read the [contribution guide](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) and do a dev install:
|
188 |
+
|
189 |
+
```bash
|
190 |
+
git clone https://github.com/huggingface/trl.git
|
191 |
+
cd trl/
|
192 |
+
pip install -e .[dev]
|
193 |
+
```
|
194 |
+
|
195 |
+
## Citation
|
196 |
+
|
197 |
+
```bibtex
|
198 |
+
@misc{vonwerra2022trl,
|
199 |
+
author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
|
200 |
+
title = {TRL: Transformer Reinforcement Learning},
|
201 |
+
year = {2020},
|
202 |
+
publisher = {GitHub},
|
203 |
+
journal = {GitHub repository},
|
204 |
+
howpublished = {\url{https://github.com/huggingface/trl}}
|
205 |
+
}
|
206 |
+
```
|
207 |
+
|
208 |
+
## License
|
209 |
+
|
210 |
+
This repository's source code is available under the [Apache-2.0 License](LICENSE).
|
commands/run_dpo.sh
ADDED
@@ -0,0 +1,58 @@
1 |
+
#!/bin/bash
|
2 |
+
# This script runs a DPO example end-to-end on a tiny model using different possible configurations
|
3 |
+
# but defaults to QLoRA + PEFT
|
4 |
+
OUTPUT_DIR="test_dpo/"
|
5 |
+
MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
|
6 |
+
DATASET_NAME="trl-internal-testing/hh-rlhf-helpful-base-trl-style"
|
7 |
+
MAX_STEPS=5
|
8 |
+
BATCH_SIZE=2
|
9 |
+
SEQ_LEN=128
|
10 |
+
|
11 |
+
# Handle extra arguments in case one passes accelerate configs.
|
12 |
+
EXTRA_ACCELERATE_ARGS=""
|
13 |
+
EXTRA_TRAINING_ARGS="""--use_peft \
|
14 |
+
--load_in_4bit
|
15 |
+
"""
|
16 |
+
|
17 |
+
# Set your number of GPUs here
|
18 |
+
NUM_GPUS=2
|
19 |
+
|
20 |
+
if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
|
21 |
+
EXTRA_ACCELERATE_ARGS=""
|
22 |
+
else
|
23 |
+
EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
|
24 |
+
# For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
|
25 |
+
# on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training.
|
26 |
+
if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
|
27 |
+
EXTRA_TRAINING_ARGS="--fp16"
|
28 |
+
else
|
29 |
+
echo "Keeping QLoRA + PEFT"
|
30 |
+
fi
|
31 |
+
fi
|
32 |
+
|
33 |
+
|
34 |
+
CMD="""
|
35 |
+
accelerate launch $EXTRA_ACCELERATE_ARGS \
|
36 |
+
--num_processes $NUM_GPUS \
|
37 |
+
--mixed_precision 'fp16' \
|
38 |
+
`pwd`/trl/scripts/dpo.py \
|
39 |
+
--model_name_or_path $MODEL_NAME \
|
40 |
+
--dataset_name $DATASET_NAME \
|
41 |
+
--output_dir $OUTPUT_DIR \
|
42 |
+
--max_steps $MAX_STEPS \
|
43 |
+
--per_device_train_batch_size $BATCH_SIZE \
|
44 |
+
--max_length $SEQ_LEN \
|
45 |
+
$EXTRA_TRAINING_ARGS
|
46 |
+
"""
|
47 |
+
|
48 |
+
echo "Starting program..."
|
49 |
+
|
50 |
+
{ # try
|
51 |
+
echo $CMD
|
52 |
+
eval "$CMD"
|
53 |
+
} || { # catch
|
54 |
+
# save log for exception
|
55 |
+
echo "Operation Failed!"
|
56 |
+
exit 1
|
57 |
+
}
|
58 |
+
exit 0
|
commands/run_sft.sh
ADDED
@@ -0,0 +1,59 @@
1 |
+
#!/bin/bash
|
2 |
+
# This script runs an SFT example end-to-end on a tiny model using different possible configurations
|
3 |
+
# but defaults to QLoRA + PEFT
|
4 |
+
OUTPUT_DIR="test_sft/"
|
5 |
+
MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
|
6 |
+
DATASET_NAME="stanfordnlp/imdb"
|
7 |
+
MAX_STEPS=5
|
8 |
+
BATCH_SIZE=2
|
9 |
+
SEQ_LEN=128
|
10 |
+
|
11 |
+
|
12 |
+
# Handle extra arguments in case one passes accelerate configs.
|
13 |
+
EXTRA_ACCELERATE_ARGS=""
|
14 |
+
EXTRA_TRAINING_ARGS="""--use_peft \
|
15 |
+
--load_in_4bit
|
16 |
+
"""
|
17 |
+
|
18 |
+
# Set your number of GPUs here
|
19 |
+
NUM_GPUS=2
|
20 |
+
|
21 |
+
if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
|
22 |
+
EXTRA_ACCELERATE_ARGS=""
|
23 |
+
else
|
24 |
+
EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
|
25 |
+
# For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
|
26 |
+
# on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training.
|
27 |
+
if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
|
28 |
+
EXTRA_TRAINING_ARGS="--fp16"
|
29 |
+
else
|
30 |
+
echo "Keeping QLoRA + PEFT"
|
31 |
+
fi
|
32 |
+
fi
|
33 |
+
|
34 |
+
|
35 |
+
CMD="""
|
36 |
+
accelerate launch $EXTRA_ACCELERATE_ARGS \
|
37 |
+
--num_processes $NUM_GPUS \
|
38 |
+
--mixed_precision 'fp16' \
|
39 |
+
`pwd`/trl/scripts/sft.py \
|
40 |
+
--model_name_or_path $MODEL_NAME \
|
41 |
+
--dataset_name $DATASET_NAME \
|
42 |
+
--output_dir $OUTPUT_DIR \
|
43 |
+
--max_steps $MAX_STEPS \
|
44 |
+
--per_device_train_batch_size $BATCH_SIZE \
|
45 |
+
--max_length $SEQ_LEN \
|
46 |
+
$EXTRA_TRAINING_ARGS
|
47 |
+
"""
|
48 |
+
|
49 |
+
echo "Starting program..."
|
50 |
+
|
51 |
+
{ # try
|
52 |
+
echo $CMD
|
53 |
+
eval "$CMD"
|
54 |
+
} || { # catch
|
55 |
+
# save log for exception
|
56 |
+
echo "Operation Failed!"
|
57 |
+
exit 1
|
58 |
+
}
|
59 |
+
exit 0
|
docker-compose.yml
ADDED
@@ -0,0 +1,5 @@
1 |
+
services:
|
2 |
+
workspace:
|
3 |
+
build:
|
4 |
+
context: .
|
5 |
+
dockerfile: Dockerfile
|
docker/trl-latest-gpu/Dockerfile
ADDED
@@ -0,0 +1,66 @@
1 |
+
# Builds GPU docker image of PyTorch
|
2 |
+
# Uses multi-staged approach to reduce size
|
3 |
+
# Stage 1
|
4 |
+
# Use base conda image to reduce time
|
5 |
+
FROM continuumio/miniconda3:latest AS compile-image
|
6 |
+
# Specify py version
|
7 |
+
ENV PYTHON_VERSION=3.10
|
8 |
+
# Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
9 |
+
RUN apt-get update && \
|
10 |
+
apt-get install -y curl git wget software-properties-common git-lfs && \
|
11 |
+
apt-get clean && \
|
12 |
+
rm -rf /var/lib/apt/lists*
|
13 |
+
|
14 |
+
# Install audio-related libraries
|
15 |
+
RUN apt-get update && \
|
16 |
+
apt install -y ffmpeg
|
17 |
+
|
18 |
+
RUN apt install -y libsndfile1-dev
|
19 |
+
RUN git lfs install
|
20 |
+
|
21 |
+
# Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
22 |
+
RUN conda create --name trl python=${PYTHON_VERSION} ipython jupyter pip
|
23 |
+
RUN python3 -m pip install --no-cache-dir --upgrade pip
|
24 |
+
|
25 |
+
# Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
26 |
+
# We don't install pytorch here yet since CUDA isn't available
|
27 |
+
# instead we use the direct torch wheel
|
28 |
+
ENV PATH /opt/conda/envs/trl/bin:$PATH
|
29 |
+
# Activate our bash shell
|
30 |
+
RUN chsh -s /bin/bash
|
31 |
+
SHELL ["/bin/bash", "-c"]
|
32 |
+
|
33 |
+
# Stage 2
|
34 |
+
FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS build-image
|
35 |
+
COPY --from=compile-image /opt/conda /opt/conda
|
36 |
+
ENV PATH /opt/conda/bin:$PATH
|
37 |
+
|
38 |
+
RUN chsh -s /bin/bash
|
39 |
+
SHELL ["/bin/bash", "-c"]
|
40 |
+
RUN source activate trl && \
|
41 |
+
python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq
|
42 |
+
|
43 |
+
# Install apt libs
|
44 |
+
RUN apt-get update && \
|
45 |
+
apt-get install -y curl git wget && \
|
46 |
+
apt-get clean && \
|
47 |
+
rm -rf /var/lib/apt/lists*
|
48 |
+
|
49 |
+
# Activate the conda env and install transformers + accelerate from source
|
50 |
+
RUN source activate trl && \
|
51 |
+
python3 -m pip install -U --no-cache-dir \
|
52 |
+
librosa \
|
53 |
+
"soundfile>=0.12.1" \
|
54 |
+
scipy \
|
55 |
+
transformers \
|
56 |
+
accelerate \
|
57 |
+
peft \
|
58 |
+
trl[test]@git+https://github.com/huggingface/trl
|
59 |
+
|
60 |
+
RUN source activate trl && \
|
61 |
+
pip freeze | grep trl
|
62 |
+
|
63 |
+
RUN echo "source activate trl" >> ~/.profile
|
64 |
+
|
65 |
+
# Activate the virtualenv
|
66 |
+
CMD ["/bin/bash"]
|
docker/trl-source-gpu/Dockerfile
ADDED
@@ -0,0 +1,66 @@
1 |
+
# Builds GPU docker image of PyTorch
|
2 |
+
# Uses multi-staged approach to reduce size
|
3 |
+
# Stage 1
|
4 |
+
# Use base conda image to reduce time
|
5 |
+
FROM continuumio/miniconda3:latest AS compile-image
|
6 |
+
# Specify py version
|
7 |
+
ENV PYTHON_VERSION=3.10
|
8 |
+
# Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
9 |
+
RUN apt-get update && \
|
10 |
+
apt-get install -y curl git wget software-properties-common git-lfs && \
|
11 |
+
apt-get clean && \
|
12 |
+
rm -rf /var/lib/apt/lists*
|
13 |
+
|
14 |
+
# Install audio-related libraries
|
15 |
+
RUN apt-get update && \
|
16 |
+
apt install -y ffmpeg
|
17 |
+
|
18 |
+
RUN apt install -y libsndfile1-dev
|
19 |
+
RUN git lfs install
|
20 |
+
|
21 |
+
# Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
22 |
+
RUN conda create --name trl python=${PYTHON_VERSION} ipython jupyter pip
|
23 |
+
RUN python3 -m pip install --no-cache-dir --upgrade pip
|
24 |
+
|
25 |
+
# Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
26 |
+
# We don't install pytorch here yet since CUDA isn't available
|
27 |
+
# instead we use the direct torch wheel
|
28 |
+
ENV PATH /opt/conda/envs/trl/bin:$PATH
|
29 |
+
# Activate our bash shell
|
30 |
+
RUN chsh -s /bin/bash
|
31 |
+
SHELL ["/bin/bash", "-c"]
|
32 |
+
|
33 |
+
# Stage 2
|
34 |
+
FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS build-image
|
35 |
+
COPY --from=compile-image /opt/conda /opt/conda
|
36 |
+
ENV PATH /opt/conda/bin:$PATH
|
37 |
+
|
38 |
+
RUN chsh -s /bin/bash
|
39 |
+
SHELL ["/bin/bash", "-c"]
|
40 |
+
RUN source activate trl && \
|
41 |
+
python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq
|
42 |
+
|
43 |
+
# Install apt libs
|
44 |
+
RUN apt-get update && \
|
45 |
+
apt-get install -y curl git wget && \
|
46 |
+
apt-get clean && \
|
47 |
+
rm -rf /var/lib/apt/lists*
|
48 |
+
|
49 |
+
# Activate the conda env and install transformers + accelerate from source
|
50 |
+
RUN source activate trl && \
|
51 |
+
python3 -m pip install -U --no-cache-dir \
|
52 |
+
librosa \
|
53 |
+
"soundfile>=0.12.1" \
|
54 |
+
scipy \
|
55 |
+
git+https://github.com/huggingface/transformers \
|
56 |
+
git+https://github.com/huggingface/accelerate \
|
57 |
+
git+https://github.com/huggingface/peft \
|
58 |
+
trl[test]@git+https://github.com/huggingface/trl
|
59 |
+
|
60 |
+
RUN source activate trl && \
|
61 |
+
pip freeze | grep transformers
|
62 |
+
|
63 |
+
RUN echo "source activate trl" >> ~/.profile
|
64 |
+
|
65 |
+
# Activate the virtualenv
|
66 |
+
CMD ["/bin/bash"]
|
docs/source/_toctree.yml
ADDED
@@ -0,0 +1,116 @@
1 |
+
- sections:
|
2 |
+
- local: index
|
3 |
+
title: TRL
|
4 |
+
- local: installation
|
5 |
+
title: Installation
|
6 |
+
- local: quickstart
|
7 |
+
title: Quickstart
|
8 |
+
title: Getting started
|
9 |
+
- sections:
|
10 |
+
- local: dataset_formats
|
11 |
+
title: Dataset Formats
|
12 |
+
- local: how_to_train
|
13 |
+
title: Training FAQ
|
14 |
+
- local: logging
|
15 |
+
title: Understanding Logs
|
16 |
+
title: Conceptual Guides
|
17 |
+
- sections:
|
18 |
+
- local: clis
|
19 |
+
title: Command Line Interface (CLI)
|
20 |
+
- local: customization
|
21 |
+
title: Customizing the Training
|
22 |
+
- local: reducing_memory_usage
|
23 |
+
title: Reducing Memory Usage
|
24 |
+
- local: speeding_up_training
|
25 |
+
title: Speeding Up Training
|
26 |
+
- local: distributing_training
|
27 |
+
title: Distributing Training
|
28 |
+
- local: use_model
|
29 |
+
title: Using Trained Models
|
30 |
+
title: How-to guides
|
31 |
+
- sections:
|
32 |
+
- local: deepspeed_integration
|
33 |
+
title: DeepSpeed
|
34 |
+
- local: liger_kernel_integration
|
35 |
+
title: Liger Kernel
|
36 |
+
- local: peft_integration
|
37 |
+
title: PEFT
|
38 |
+
- local: unsloth_integration
|
39 |
+
title: Unsloth
|
40 |
+
- local: vllm_integration
|
41 |
+
title: vLLM
|
42 |
+
title: Integrations
|
43 |
+
- sections:
|
44 |
+
- local: example_overview
|
45 |
+
title: Example Overview
|
46 |
+
- local: community_tutorials
|
47 |
+
title: Community Tutorials
|
48 |
+
- local: sentiment_tuning
|
49 |
+
title: Sentiment Tuning
|
50 |
+
- local: using_llama_models
|
51 |
+
title: Training StackLlama
|
52 |
+
- local: detoxifying_a_lm
|
53 |
+
title: Detoxifying a Language Model
|
54 |
+
- local: multi_adapter_rl
|
55 |
+
title: Multi Adapter RLHF
|
56 |
+
- local: training_vlm_sft
|
57 |
+
title: Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset)
|
58 |
+
title: Examples
|
59 |
+
- sections:
|
60 |
+
- sections: # Sorted alphabetically
|
61 |
+
- local: alignprop_trainer
|
62 |
+
title: AlignProp
|
63 |
+
- local: bco_trainer
|
64 |
+
title: BCO
|
65 |
+
- local: cpo_trainer
|
66 |
+
title: CPO
|
67 |
+
- local: ddpo_trainer
|
68 |
+
title: DDPO
|
69 |
+
- local: dpo_trainer
|
70 |
+
title: DPO
|
71 |
+
- local: online_dpo_trainer
|
72 |
+
title: Online DPO
|
73 |
+
- local: gkd_trainer
|
74 |
+
title: GKD
|
75 |
+
- local: grpo_trainer
|
76 |
+
title: GRPO
|
77 |
+
- local: kto_trainer
|
78 |
+
title: KTO
|
79 |
+
- local: nash_md_trainer
|
80 |
+
title: Nash-MD
|
81 |
+
- local: orpo_trainer
|
82 |
+
title: ORPO
|
83 |
+
- local: ppo_trainer
|
84 |
+
title: PPO
|
85 |
+
- local: prm_trainer
|
86 |
+
title: PRM
|
87 |
+
- local: reward_trainer
|
88 |
+
title: Reward
|
89 |
+
- local: rloo_trainer
|
90 |
+
title: RLOO
|
91 |
+
- local: sft_trainer
|
92 |
+
title: SFT
|
93 |
+
- local: iterative_sft_trainer
|
94 |
+
title: Iterative SFT
|
95 |
+
- local: xpo_trainer
|
96 |
+
title: XPO
|
97 |
+
title: Trainers
|
98 |
+
- local: models
|
99 |
+
title: Model Classes
|
100 |
+
- local: model_utils
|
101 |
+
title: Model Utilities
|
102 |
+
- local: best_of_n
|
103 |
+
title: Best of N Sampling
|
104 |
+
- local: judges
|
105 |
+
title: Judges
|
106 |
+
- local: callbacks
|
107 |
+
title: Callbacks
|
108 |
+
- local: data_utils
|
109 |
+
title: Data Utilities
|
110 |
+
- local: rewards
|
111 |
+
title: Reward Functions
|
112 |
+
- local: script_utils
|
113 |
+
title: Script Utilities
|
114 |
+
- local: others
|
115 |
+
title: Others
|
116 |
+
title: API
|
docs/source/alignprop_trainer.md
ADDED
@@ -0,0 +1,93 @@
1 |
+
# Aligning Text-to-Image Diffusion Models with Reward Backpropagation
|
2 |
+
|
3 |
+
[](https://huggingface.co/models?other=alignprop,trl)
|
4 |
+
|
5 |
+
## The why
|
6 |
+
|
7 |
+
If your reward function is differentiable, directly backpropagating gradients from the reward models to the diffusion model is significantly more sample- and compute-efficient (25x) than using a policy gradient algorithm like DDPO.
|
8 |
+
AlignProp does full backpropagation through time, which allows updating the earlier steps of denoising via reward backpropagation.
|
9 |
+
|
10 |
+
<div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/reward_tuning.png"/></div>
|
11 |
+
|
12 |
+
|
13 |
+
## Getting started with `examples/scripts/alignprop.py`
|
14 |
+
|
15 |
+
The `alignprop.py` script is a working example of using the `AlignProp` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`AlignPropConfig`).
|
16 |
+
|
17 |
+
**Note:** one A100 GPU is recommended to get this running. For a lower-memory setting, consider setting `truncated_backprop_rand` to False. With default settings this will do truncated backpropagation with K=1.
|
18 |
+
|
19 |
+
Almost every configuration parameter has a default. Only one command-line argument is required to get things up and running: a [Hugging Face user access token](https://huggingface.co/docs/hub/security-tokens), which is used to upload the model to the Hugging Face Hub after fine-tuning. Enter the following bash command to get things running:
|
20 |
+
|
21 |
+
```bash
|
22 |
+
python alignprop.py --hf_user_access_token <token>
|
23 |
+
```
|
24 |
+
|
25 |
+
To obtain the documentation of `alignprop.py`, please run `python alignprop.py --help`
|
26 |
+
|
27 |
+
The following are things to keep in mind while configuring the trainer (the code checks these for you as well), beyond the use case of the example script; a configuration sketch follows the list:
|
28 |
+
|
29 |
+
- The configurable randomized truncation range (`--alignprop_config.truncated_rand_backprop_minmax=(0,50)`): the first number should be greater than or equal to 0, while the second number should be less than or equal to the number of diffusion timesteps (`sample_num_steps`)
|
30 |
+
- The configurable truncated backprop absolute step (`--alignprop_config.truncated_backprop_timestep=49`): the number should be less than the number of diffusion timesteps (`sample_num_steps`); it only matters when `truncated_backprop_rand` is set to False
|
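As a rough illustration, here is a minimal sketch of such a configuration. It assumes the parameter names referenced in the list above (`sample_num_steps`, `truncated_backprop_rand`, `truncated_rand_backprop_minmax`, `truncated_backprop_timestep`) are fields of `AlignPropConfig`; the values are illustrative, not recommendations.

```python
from trl import AlignPropConfig

# Illustrative values only: keep the truncation range within [0, sample_num_steps].
config = AlignPropConfig(
    sample_num_steps=50,
    truncated_backprop_rand=True,            # randomized truncation (the default behaviour)
    truncated_rand_backprop_minmax=(0, 50),  # 0 <= min, and max <= sample_num_steps
    # truncated_backprop_timestep=49,        # only used when truncated_backprop_rand=False
)
```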
31 |
+
|
32 |
+
## Setting up the image logging hook function
|
33 |
+
|
34 |
+
Expect the function to be given a dictionary with keys
|
35 |
+
```python
|
36 |
+
['image', 'prompt', 'prompt_metadata', 'rewards']
|
37 |
+
|
38 |
+
```
|
39 |
+
and `image`, `prompt`, `prompt_metadata`, and `rewards` are batched.
|
40 |
+
You are free to log however you want; the use of `wandb` or `tensorboard` is recommended.
|
41 |
+
|
42 |
+
### Key terms
|
43 |
+
|
44 |
+
- `rewards` : The reward/score is a numerical value associated with the generated image and is key to steering the RL process
|
45 |
+
- `prompt` : The prompt is the text that is used to generate the image
|
46 |
+
- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup where questions and ground answers (linked to the generated image) are expected with the generated image (See here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
|
47 |
+
- `image` : The image generated by the Stable Diffusion model
|
48 |
+
|
49 |
+
Example code for logging sampled images with `wandb` is given below.
|
50 |
+
|
51 |
+
```python
|
52 |
+
# for logging these images to wandb
|
53 |
+
|
54 |
+
def image_outputs_hook(image_data, global_step, accelerate_logger):
|
55 |
+
# For the sake of this example, we only care about the last batch
|
56 |
+
# hence we extract the last element of the list
|
57 |
+
result = {}
|
58 |
+
images, prompts, rewards = [image_data['images'],image_data['prompts'],image_data['rewards']]
|
59 |
+
for i, image in enumerate(images):
|
60 |
+
pil = Image.fromarray(
|
61 |
+
(image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
|
62 |
+
)
|
63 |
+
pil = pil.resize((256, 256))
|
64 |
+
result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
|
65 |
+
accelerate_logger.log_images(
|
66 |
+
result,
|
67 |
+
step=global_step,
|
68 |
+
)
|
69 |
+
|
70 |
+
```
|
71 |
+
|
72 |
+
### Using the finetuned model
|
73 |
+
|
74 |
+
Assuming you're done with all the epochs and have pushed your model to the Hub, you can use the finetuned model as follows
|
75 |
+
|
76 |
+
```python
|
77 |
+
from diffusers import StableDiffusionPipeline
|
78 |
+
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
79 |
+
pipeline.to("cuda")
|
80 |
+
|
81 |
+
pipeline.load_lora_weights('mihirpd/alignprop-trl-aesthetics')
|
82 |
+
|
83 |
+
prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
|
84 |
+
results = pipeline(prompts)
|
85 |
+
|
86 |
+
for prompt, image in zip(prompts,results.images):
|
87 |
+
image.save(f"dump/{prompt}.png")
|
88 |
+
```
|
89 |
+
|
90 |
+
## Credits
|
91 |
+
|
92 |
+
This work is heavily influenced by the repo [here](https://github.com/mihirp1998/AlignProp/) and the associated paper [Aligning Text-to-Image Diffusion Models with Reward Backpropagation
|
93 |
+
by Mihir Prabhudesai, Anirudh Goyal, Deepak Pathak, Katerina Fragkiadaki](https://huggingface.co/papers/2310.03739).
|
docs/source/bco_trainer.md
ADDED
@@ -0,0 +1,100 @@
1 |
+
# BCO Trainer
|
2 |
+
|
3 |
+
[](https://huggingface.co/models?other=bco,trl)
|
4 |
+
|
5 |
+
TRL supports the Binary Classifier Optimization (BCO).
|
6 |
+
The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0.
|
7 |
+
For a full example have a look at [`examples/scripts/bco.py`].
|
8 |
+
|
9 |
+
## Expected dataset type
|
10 |
+
|
11 |
+
The [`BCOTrainer`] requires an [unpaired preference dataset](dataset_formats#unpaired-preference).
|
12 |
+
The [`BCOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
|
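For illustration, a single example in the standard unpaired preference format looks roughly like this (the field values below are made up):

```python
# A made-up unpaired preference example: a prompt, a single completion,
# and a boolean label indicating whether the completion is desirable.
example = {
    "prompt": "The sky is",
    "completion": " blue.",
    "label": True,
}
```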
13 |
+
|
14 |
+
## Expected model format
|
15 |
+
The BCO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
|
16 |
+
|
17 |
+
## Using the `BCOTrainer`
|
18 |
+
|
19 |
+
For a detailed example have a look at the `examples/scripts/bco.py` script. At a high level we need to initialize the `BCOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response.
|
20 |
+
|
21 |
+
The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (i.e. decoder-only or encoder-decoder).
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
```py
|
26 |
+
training_args = BCOConfig(
|
27 |
+
beta=0.1,
|
28 |
+
)
|
29 |
+
|
30 |
+
bco_trainer = BCOTrainer(
|
31 |
+
model,
|
32 |
+
model_ref,
|
33 |
+
args=training_args,
|
34 |
+
train_dataset=train_dataset,
|
35 |
+
processing_class=tokenizer,
|
36 |
+
)
|
37 |
+
```
|
38 |
+
After this one can then call:
|
39 |
+
|
40 |
+
```py
|
41 |
+
bco_trainer.train()
|
42 |
+
```
|
43 |
+
|
44 |
+
## Underlying Distribution matching (UDM)
|
45 |
+
|
46 |
+
In practical scenarios, the thumbs-up and thumbs-down datasets are likely to have divergent underlying distributions of prompts.
|
47 |
+
Consider an LLM deployed for user feedback: if the model excels in writing tasks but underperforms in coding, the thumbs-up dataset will be dominated by writing-related prompts, while the thumbs-down dataset will contain mostly coding-related prompts.
|
48 |
+
If the prompts in your desired and undesired datasets differ a lot, it is useful to enable UDM.
|
49 |
+
|
50 |
+
Choose an embedding model and tokenizer:
|
51 |
+
|
52 |
+
```py
|
53 |
+
embedding_model = AutoModel.from_pretrained(your_model_id)
|
54 |
+
embedding_tokenizer = AutoTokenizer.from_pretrained(your_model_id)
|
55 |
+
|
56 |
+
# customize this function depending on your embedding model
|
57 |
+
def embed_prompt(input_ids, attention_mask, model):
|
58 |
+
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
59 |
+
return outputs.last_hidden_state.mean(dim=1)
|
60 |
+
|
61 |
+
embedding_model = Accelerator().prepare_model(embedding_model)
|
62 |
+
embedding_func = partial(embed_prompt, model=embedding_model)
|
63 |
+
```
|
64 |
+
|
65 |
+
Set `prompt_sample_size` to define how many prompts are selected to train the UDM classifier and start the training with the provided embedding function:
|
66 |
+
|
67 |
+
```py
|
68 |
+
training_args = BCOConfig(
|
69 |
+
beta=0.1,
|
70 |
+
prompt_sample_size=512,
|
71 |
+
)
|
72 |
+
|
73 |
+
bco_trainer = BCOTrainer(
|
74 |
+
model,
|
75 |
+
model_ref,
|
76 |
+
args=training_args,
|
77 |
+
train_dataset=train_dataset,
|
78 |
+
processing_class=tokenizer,
|
79 |
+
embedding_func=embedding_func,
|
80 |
+
    embedding_tokenizer=embedding_tokenizer,
|
81 |
+
)
|
82 |
+
|
83 |
+
bco_trainer.train()
|
84 |
+
```
|
85 |
+
|
86 |
+
### For Mixture of Experts Models: Enabling the auxiliary loss
|
87 |
+
|
88 |
+
MOEs are the most efficient if the load is about equally distributed between experts.
|
89 |
+
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
|
90 |
+
|
91 |
+
This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig).
|
92 |
+
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001).
|
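As a minimal sketch, these two options can be passed when loading the model (the checkpoint name and coefficient below are illustrative, not recommendations):

```python
from transformers import AutoModelForCausalLM

# Enable the router auxiliary loss so expert load balancing is preserved during
# preference-tuning; router_aux_loss_coef scales its contribution to the total loss.
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",  # illustrative MoE checkpoint
    output_router_logits=True,
    router_aux_loss_coef=0.001,
)
```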
93 |
+
|
94 |
+
## BCOTrainer
|
95 |
+
|
96 |
+
[[autodoc]] BCOTrainer
|
97 |
+
|
98 |
+
## BCOConfig
|
99 |
+
|
100 |
+
[[autodoc]] BCOConfig
|
docs/source/best_of_n.md
ADDED
@@ -0,0 +1,72 @@
1 |
+
# Best of N sampling: Alternative ways to get better model output without RL based fine-tuning
|
2 |
+
|
3 |
+
Within the extras module is the `best-of-n` sampler class that serves as an alternative method of generating better model output.
|
4 |
+
As to how it fares against RL-based fine-tuning, please look in the `examples` directory for a comparison example
|
5 |
+
|
6 |
+
## Usage
|
7 |
+
|
8 |
+
To get started quickly, instantiate the class with a model, a length sampler, a tokenizer and a callable that serves as a proxy reward pipeline that outputs reward scores for input queries
|
9 |
+
|
10 |
+
```python
|
11 |
+
|
12 |
+
from transformers import pipeline, AutoTokenizer
|
13 |
+
from trl import AutoModelForCausalLMWithValueHead
|
14 |
+
from trl.core import LengthSampler
|
15 |
+
from trl.extras import BestOfNSampler
|
16 |
+
|
17 |
+
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)
|
18 |
+
reward_pipe = pipeline("sentiment-analysis", model=reward_model, device=device)
|
19 |
+
tokenizer = AutoTokenizer.from_pretrained(ref_model_name)
|
20 |
+
tokenizer.pad_token = tokenizer.eos_token
|
21 |
+
|
22 |
+
|
23 |
+
# callable that takes a list of raw text and returns a list of corresponding reward scores
|
24 |
+
def queries_to_scores(list_of_strings):
|
25 |
+
return [output["score"] for output in reward_pipe(list_of_strings)]
|
26 |
+
|
27 |
+
best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler)
|
28 |
+
|
29 |
+
|
30 |
+
```
|
31 |
+
|
32 |
+
And assuming you have a list/tensor of tokenized queries, you can generate better output by calling the `generate` method
|
33 |
+
|
34 |
+
```python
|
35 |
+
|
36 |
+
best_of_n.generate(query_tensors, device=device, **gen_kwargs)
|
37 |
+
|
38 |
+
```
|
39 |
+
The default sample size is 4, but you can change it at the time of instance initialization like so
|
40 |
+
|
41 |
+
```python
|
42 |
+
|
43 |
+
best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, sample_size=8)
|
44 |
+
|
45 |
+
```
|
46 |
+
|
47 |
+
The default output is the result of taking the top scored output for each query, but you can change it to top 2 and so on by passing the `n_candidates` argument at the time of instance initialization
|
48 |
+
|
49 |
+
```python
|
50 |
+
|
51 |
+
best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, n_candidates=2)
|
52 |
+
|
53 |
+
```
|
54 |
+
|
55 |
+
There is the option of setting the generation settings (like `temperature`, `pad_token_id`) at the time of instance creation as opposed to when calling the `generate` method.
|
56 |
+
This is done by passing a `GenerationConfig` from the `transformers` library at the time of initialization
|
57 |
+
|
58 |
+
```python
|
59 |
+
|
60 |
+
from transformers import GenerationConfig
|
61 |
+
|
62 |
+
generation_config = GenerationConfig(min_length=-1, top_k=0, top_p=1.0, do_sample=True, pad_token_id=tokenizer.eos_token_id)
|
63 |
+
|
64 |
+
best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, generation_config=generation_config)
|
65 |
+
|
66 |
+
best_of_n.generate(query_tensors, device=device)
|
67 |
+
|
68 |
+
```
|
69 |
+
|
70 |
+
Furthermore, at the time of initialization you can set the seed to control the repeatability of the generation process and the number of samples to generate for each query, as shown in the sketch below.
|
71 |
+
|
72 |
+
|
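A minimal sketch of what this looks like, reusing the `model`, `tokenizer`, `queries_to_scores`, and `output_length_sampler` objects defined above (the values are illustrative):

```python
best_of_n = BestOfNSampler(
    model,
    tokenizer,
    queries_to_scores,
    length_sampler=output_length_sampler,
    seed=0,          # fix the seed for repeatable generation
    sample_size=8,   # number of samples generated for each query
)
```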
docs/source/callbacks.md
ADDED
@@ -0,0 +1,21 @@
1 |
+
# Callbacks
|
2 |
+
|
3 |
+
## SyncRefModelCallback
|
4 |
+
|
5 |
+
[[autodoc]] SyncRefModelCallback
|
6 |
+
|
7 |
+
## RichProgressCallback
|
8 |
+
|
9 |
+
[[autodoc]] RichProgressCallback
|
10 |
+
|
11 |
+
## WinRateCallback
|
12 |
+
|
13 |
+
[[autodoc]] WinRateCallback
|
14 |
+
|
15 |
+
## LogCompletionsCallback
|
16 |
+
|
17 |
+
[[autodoc]] LogCompletionsCallback
|
18 |
+
|
19 |
+
## MergeModelCallback
|
20 |
+
|
21 |
+
[[autodoc]] MergeModelCallback
|
docs/source/clis.md
ADDED
@@ -0,0 +1,272 @@
1 |
+
# Command Line Interfaces (CLIs)
|
2 |
+
|
3 |
+
TRL provides a powerful command-line interface (CLI) to fine-tune large language models (LLMs) using methods like Supervised Fine-Tuning (SFT), Direct Preference Optimization (DPO), and more. The CLI abstracts away much of the boilerplate, letting you launch training jobs quickly and reproducibly.
|
4 |
+
|
5 |
+
Currently supported commands are:
|
6 |
+
|
7 |
+
#### Training Commands
|
8 |
+
|
9 |
+
- `trl dpo`: fine-tune an LLM with DPO
|
10 |
+
- `trl grpo`: fine-tune an LLM with GRPO
|
11 |
+
- `trl kto`: fine-tune an LLM with KTO
|
12 |
+
- `trl sft`: fine-tune an LLM with SFT
|
13 |
+
|
14 |
+
#### Other Commands
|
15 |
+
|
16 |
+
- `trl env`: get the system information
|
17 |
+
- `trl vllm-serve`: serve a model with vLLM
|
18 |
+
|
19 |
+
## Fine-Tuning with the TRL CLI
|
20 |
+
|
21 |
+
### Basic Usage
|
22 |
+
|
23 |
+
You can launch training directly from the CLI by specifying required arguments like the model and dataset:
|
24 |
+
|
25 |
+
<hfoptions id="command_line">
|
26 |
+
<hfoption id="SFT">
|
27 |
+
|
28 |
+
```bash
|
29 |
+
trl sft \
|
30 |
+
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
31 |
+
--dataset_name stanfordnlp/imdb
|
32 |
+
```
|
33 |
+
|
34 |
+
</hfoption>
|
35 |
+
<hfoption id="DPO">
|
36 |
+
|
37 |
+
```bash
|
38 |
+
trl dpo \
|
39 |
+
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
40 |
+
--dataset_name anthropic/hh-rlhf
|
41 |
+
```
|
42 |
+
|
43 |
+
</hfoption>
|
44 |
+
</hfoptions>
|
45 |
+
|
46 |
+
### Using Configuration Files
|
47 |
+
|
48 |
+
To keep your CLI commands clean and reproducible, you can define all training arguments in a YAML configuration file:
|
49 |
+
|
50 |
+
<hfoptions id="config_file">
|
51 |
+
<hfoption id="SFT">
|
52 |
+
|
53 |
+
```yaml
|
54 |
+
# sft_config.yaml
|
55 |
+
model_name_or_path: Qwen/Qwen2.5-0.5B
|
56 |
+
dataset_name: stanfordnlp/imdb
|
57 |
+
```
|
58 |
+
|
59 |
+
Launch with:
|
60 |
+
|
61 |
+
```bash
|
62 |
+
trl sft --config sft_config.yaml
|
63 |
+
```
|
64 |
+
|
65 |
+
</hfoption>
|
66 |
+
<hfoption id="DPO">
|
67 |
+
|
68 |
+
```yaml
|
69 |
+
# dpo_config.yaml
|
70 |
+
model_name_or_path: Qwen/Qwen2.5-0.5B
|
71 |
+
dataset_name: anthropic/hh-rlhf
|
72 |
+
```
|
73 |
+
|
74 |
+
Launch with:
|
75 |
+
|
76 |
+
```bash
|
77 |
+
trl dpo --config dpo_config.yaml
|
78 |
+
```
|
79 |
+
|
80 |
+
</hfoption>
|
81 |
+
</hfoptions>
|
82 |
+
|
83 |
+
### Scaling Up with Accelerate
|
84 |
+
|
85 |
+
The TRL CLI natively supports [🤗 Accelerate](https://huggingface.co/docs/accelerate), making it easy to scale training across multiple GPUs or machines, or to use advanced setups like DeepSpeed, all from the same CLI.
|
86 |
+
|
87 |
+
You can pass any `accelerate launch` arguments directly to `trl`, such as `--num_processes`. For more information see [Using accelerate launch](https://huggingface.co/docs/accelerate/en/basic_tutorials/launch#using-accelerate-launch).
|
88 |
+
|
89 |
+
<hfoptions id="launch_args">
|
90 |
+
<hfoption id="SFT inline">
|
91 |
+
|
92 |
+
```bash
|
93 |
+
trl sft \
|
94 |
+
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
95 |
+
--dataset_name stanfordnlp/imdb \
|
96 |
+
--num_processes 4
|
97 |
+
```
|
98 |
+
|
99 |
+
</hfoption>
|
100 |
+
<hfoption id="SFT w/ config file">
|
101 |
+
|
102 |
+
```yaml
|
103 |
+
# sft_config.yaml
|
104 |
+
model_name_or_path: Qwen/Qwen2.5-0.5B
|
105 |
+
dataset_name: stanfordnlp/imdb
|
106 |
+
num_processes: 4
|
107 |
+
```
|
108 |
+
|
109 |
+
Launch with:
|
110 |
+
|
111 |
+
```bash
|
112 |
+
trl sft --config sft_config.yaml
|
113 |
+
```
|
114 |
+
|
115 |
+
</hfoption>
|
116 |
+
<hfoption id="DPO inline">
|
117 |
+
|
118 |
+
```bash
|
119 |
+
trl dpo \
|
120 |
+
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
121 |
+
--dataset_name anthropic/hh-rlhf \
|
122 |
+
--num_processes 4
|
123 |
+
```
|
124 |
+
|
125 |
+
</hfoption>
|
126 |
+
<hfoption id="DPO w/ config file">
|
127 |
+
|
128 |
+
```yaml
|
129 |
+
# dpo_config.yaml
|
130 |
+
model_name_or_path: Qwen/Qwen2.5-0.5B
|
131 |
+
dataset_name: anthropic/hh-rlhf
|
132 |
+
num_processes: 4
|
133 |
+
```
|
134 |
+
|
135 |
+
Launch with:
|
136 |
+
|
137 |
+
```bash
|
138 |
+
trl dpo --config dpo_config.yaml
|
139 |
+
```
|
140 |
+
</hfoption>
|
141 |
+
</hfoptions>
|
142 |
+
|
143 |
+
### Using `--accelerate_config` for Accelerate Configuration
|
144 |
+
|
145 |
+
The `--accelerate_config` flag lets you easily configure distributed training with [🤗 Accelerate](https://github.com/huggingface/accelerate). This flag accepts either:
|
146 |
+
|
147 |
+
* the name of a predefined config profile (built into TRL), or
|
148 |
+
* a path to a custom Accelerate YAML config file.
|
149 |
+
|
150 |
+
#### Predefined Config Profiles
|
151 |
+
|
152 |
+
TRL provides several ready-to-use Accelerate configs to simplify common training setups:
|
153 |
+
|
154 |
+
| Name | Description |
|
155 |
+
| ------------ | ----------------------------------- |
|
156 |
+
| `fsdp1` | Fully Sharded Data Parallel Stage 1 |
|
157 |
+
| `fsdp2` | Fully Sharded Data Parallel Stage 2 |
|
158 |
+
| `zero1` | DeepSpeed ZeRO Stage 1 |
|
159 |
+
| `zero2` | DeepSpeed ZeRO Stage 2 |
|
160 |
+
| `zero3` | DeepSpeed ZeRO Stage 3 |
|
161 |
+
| `multi_gpu` | Multi-GPU training |
|
162 |
+
| `single_gpu` | Single-GPU training |
|
163 |
+
|
164 |
+
To use one of these, just pass the name to `--accelerate_config`. TRL will automatically load the corresponding config file from `trl/accelerate_config/`.
|
165 |
+
|
166 |
+
#### Example Usage
|
167 |
+
|
168 |
+
<hfoptions id="accelerate_config">
|
169 |
+
<hfoption id="SFT inline">
|
170 |
+
|
171 |
+
```bash
|
172 |
+
trl sft \
|
173 |
+
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
174 |
+
--dataset_name stanfordnlp/imdb \
|
175 |
+
--accelerate_config zero2 # or path/to/my/accelerate/config.yaml
|
176 |
+
```
|
177 |
+
|
178 |
+
</hfoption>
|
179 |
+
<hfoption id="SFT w/ config file">
|
180 |
+
|
181 |
+
```yaml
|
182 |
+
# sft_config.yaml
|
183 |
+
model_name_or_path: Qwen/Qwen2.5-0.5B
|
184 |
+
dataset_name: stanfordnlp/imdb
|
185 |
+
accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
|
186 |
+
```
|
187 |
+
|
188 |
+
Launch with:
|
189 |
+
|
190 |
+
```bash
|
191 |
+
trl sft --config sft_config.yaml
|
192 |
+
```
|
193 |
+
|
194 |
+
</hfoption>
|
195 |
+
<hfoption id="DPO inline">
|
196 |
+
|
197 |
+
```bash
|
198 |
+
trl dpo \
|
199 |
+
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
200 |
+
--dataset_name anthropic/hh-rlhf \
|
201 |
+
--accelerate_config zero2 # or path/to/my/accelerate/config.yaml
|
202 |
+
```
|
203 |
+
|
204 |
+
</hfoption>
|
205 |
+
<hfoption id="DPO w/ config file">
|
206 |
+
|
207 |
+
```yaml
|
208 |
+
# dpo_config.yaml
|
209 |
+
model_name_or_path: Qwen/Qwen2.5-0.5B
|
210 |
+
dataset_name: anthropic/hh-rlhf
|
211 |
+
accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
|
212 |
+
```
|
213 |
+
|
214 |
+
Launch with:
|
215 |
+
|
216 |
+
```bash
|
217 |
+
trl dpo --config dpo_config.yaml
|
218 |
+
```
|
219 |
+
</hfoption>
|
220 |
+
</hfoptions>
|
221 |
+
|
222 |
+
## Getting the System Information
|
223 |
+
|
224 |
+
You can get the system information by running the following command:
|
225 |
+
|
226 |
+
```bash
|
227 |
+
trl env
|
228 |
+
```
|
229 |
+
|
230 |
+
This will print out the system information, including the GPU information, the CUDA version, the PyTorch version, the transformers version, the TRL version, and any optional dependencies that are installed.
|
231 |
+
|
232 |
+
```txt
|
233 |
+
Copy-paste the following information when reporting an issue:
|
234 |
+
|
235 |
+
- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
|
236 |
+
- Python version: 3.11.9
|
237 |
+
- PyTorch version: 2.4.1
|
238 |
+
- accelerator(s): NVIDIA H100 80GB HBM3
|
239 |
+
- Transformers version: 4.45.0.dev0
|
240 |
+
- Accelerate version: 0.34.2
|
241 |
+
- Accelerate config:
|
242 |
+
- compute_environment: LOCAL_MACHINE
|
243 |
+
- distributed_type: DEEPSPEED
|
244 |
+
- mixed_precision: no
|
245 |
+
- use_cpu: False
|
246 |
+
- debug: False
|
247 |
+
- num_processes: 4
|
248 |
+
- machine_rank: 0
|
249 |
+
- num_machines: 1
|
250 |
+
- rdzv_backend: static
|
251 |
+
- same_network: True
|
252 |
+
- main_training_function: main
|
253 |
+
- enable_cpu_affinity: False
|
254 |
+
- deepspeed_config: {'gradient_accumulation_steps': 4, 'offload_optimizer_device': 'none', 'offload_param_device': 'none', 'zero3_init_flag': False, 'zero_stage': 2}
|
255 |
+
- downcast_bf16: no
|
256 |
+
- tpu_use_cluster: False
|
257 |
+
- tpu_use_sudo: False
|
258 |
+
- tpu_env: []
|
259 |
+
- Datasets version: 3.0.0
|
260 |
+
- HF Hub version: 0.24.7
|
261 |
+
- TRL version: 0.12.0.dev0+acb4d70
|
262 |
+
- bitsandbytes version: 0.41.1
|
263 |
+
- DeepSpeed version: 0.15.1
|
264 |
+
- Diffusers version: 0.30.3
|
265 |
+
- Liger-Kernel version: 0.3.0
|
266 |
+
- LLM-Blender version: 0.0.2
|
267 |
+
- OpenAI version: 1.46.0
|
268 |
+
- PEFT version: 0.12.0
|
269 |
+
- vLLM version: not installed
|
270 |
+
```
|
271 |
+
|
272 |
+
This information is required when reporting an issue.
|
docs/source/community_tutorials.md
ADDED
@@ -0,0 +1,32 @@
1 |
+
# Community Tutorials
|
2 |
+
|
3 |
+
Community tutorials are made by active members of the Hugging Face community who want to share their knowledge and expertise with others. They are a great way to learn about the library and its features, and to get started with core classes and modalities.
|
4 |
+
|
5 |
+
# Language Models
|
6 |
+
|
7 |
+
| Task | Class | Description | Author | Tutorial | Colab |
|
8 |
+
| --- | --- | --- | --- | --- | --- |
|
9 |
+
| Reinforcement Learning | [`GRPOTrainer`] | Post training an LLM for reasoning with GRPO in TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_llm_grpo_trl) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_grpo_trl.ipynb) |
|
10 |
+
| Reinforcement Learning | [`GRPOTrainer`] | Mini-R1: Reproduce Deepseek R1 „aha moment“ a RL tutorial | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/mini-deepseek-r1) | [](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/mini-deepseek-r1-aha-grpo.ipynb) |
|
11 |
+
| Reinforcement Learning | [`GRPOTrainer`] | RL on LLaMA 3.1-8B with GRPO and Unsloth optimizations | [Andrea Manzoni](https://huggingface.co/AManzoni) | [Link](https://colab.research.google.com/github/amanzoni1/fine_tuning/blob/main/RL_LLama3_1_8B_GRPO.ipynb) | [](https://colab.research.google.com/github/amanzoni1/fine_tuning/blob/main/RL_LLama3_1_8B_GRPO.ipynb) |
|
12 |
+
| Instruction tuning | [`SFTTrainer`] | Fine-tuning Google Gemma LLMs using ChatML format with QLoRA | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-google-gemma) | [](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/gemma-lora-example.ipynb) |
|
13 |
+
| Structured Generation | [`SFTTrainer`] | Fine-tuning Llama-2-7B to generate Persian product catalogs in JSON using QLoRA and PEFT | [Mohammadreza Esmaeilian](https://huggingface.co/Mohammadreza) | [Link](https://huggingface.co/learn/cookbook/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format.ipynb) |
|
14 |
+
| Preference Optimization | [`DPOTrainer`] | Align Mistral-7b using Direct Preference Optimization for human preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/Fine_tune_Mistral_7b_with_DPO.html) | [](https://colab.research.google.com/github/mlabonne/llm-course/blob/main/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb) |
|
15 |
+
| Preference Optimization | [`ORPOTrainer`] | Fine-tuning Llama 3 with ORPO combining instruction tuning and preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/2024-04-19_Fine_tune_Llama_3_with_ORPO.html) | [](https://colab.research.google.com/drive/1eHNWg9gnaXErdAa8_mcvjMupbSS6rDvi) |
|
16 |
+
| Instruction tuning | [`SFTTrainer`] | How to fine-tune open LLMs in 2025 with Hugging Face | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-llms-in-2025) | [](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-llms-in-2025.ipynb) |
|
17 |
+
|
18 |
+
<Youtube id="cnGyyM0vOes" />
|
19 |
+
|
20 |
+
# Vision Language Models
|
21 |
+
|
22 |
+
| Task | Class | Description | Author | Tutorial | Colab |
|
23 |
+
| --- | --- | --- | --- | --- | --- |
|
24 |
+
| Visual QA | [`SFTTrainer`] | Fine-tuning Qwen2-VL-7B for visual question answering on ChartQA dataset | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_trl) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_trl.ipynb) |
|
25 |
+
| Visual QA | [`SFTTrainer`] | Fine-tuning SmolVLM with TRL on a consumer GPU | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_smol_vlm_sft_trl) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_smol_vlm_sft_trl.ipynb) |
|
26 |
+
| SEO Description | [`SFTTrainer`] | Fine-tuning Qwen2-VL-7B for generating SEO-friendly descriptions from images | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-multimodal-llms-with-trl) | [](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-multimodal-llms-with-trl.ipynb) |
|
27 |
+
| Visual QA | [`DPOTrainer`] | PaliGemma 🤝 Direct Preference Optimization | [Merve Noyan](https://huggingface.co/merve) | [Link](https://github.com/merveenoyan/smol-vision/blob/main/PaliGemma_DPO.ipynb) | [](https://colab.research.google.com/github/merveenoyan/smol-vision/blob/main/PaliGemma_DPO.ipynb) |
|
28 |
+
| Visual QA | [`DPOTrainer`] | Fine-tuning SmolVLM using direct preference optimization (DPO) with TRL on a consumer GPU | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_dpo_smolvlm_instruct) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_dpo_smolvlm_instruct.ipynb) |
|
29 |
+
|
30 |
+
## Contributing
|
31 |
+
|
32 |
+
If you have a tutorial that you would like to add to this list, please open a PR to add it. We will review it and merge it if it is relevant to the community.
|
docs/source/cpo_trainer.md
ADDED
@@ -0,0 +1,108 @@
1 |
+
# CPO Trainer
|
2 |
+
|
3 |
+
[](https://huggingface.co/models?other=cpo,trl)
|
4 |
+
|
5 |
+
## Overview
|
6 |
+
|
7 |
+
Contrastive Preference Optimization (CPO) was introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by [Haoran Xu](https://huggingface.co/haoranxu), [Amr Sharaf](https://huggingface.co/amrsharaf), [Yunmo Chen](https://huggingface.co/yunmochen), Weiting Tan, Lingfeng Shen, Benjamin Van Durme, [Kenton Murray](https://huggingface.co/Kenton), and [Young Jin Kim](https://huggingface.co/ykim362). At a high level, CPO trains models to avoid generating adequate, but not perfect, translations in Machine Translation (MT) tasks. However, CPO is a general approximation of the DPO loss and can be applied to other domains, such as chat.
|
8 |
+
|
9 |
+
CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT’s methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Secondly, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective.
|
10 |
+
|
11 |
+
## Quick start
|
12 |
+
|
13 |
+
This example demonstrates how to train a model using the CPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here:
|
14 |
+
|
15 |
+
<iframe
|
16 |
+
src="https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized/embed/viewer/default/train?row=0"
|
17 |
+
frameborder="0"
|
18 |
+
width="100%"
|
19 |
+
height="560px"
|
20 |
+
></iframe>
|
21 |
+
|
22 |
+
Below is the script to train the model:
|
23 |
+
|
24 |
+
```python
|
25 |
+
# train_cpo.py
|
26 |
+
from datasets import load_dataset
|
27 |
+
from trl import CPOConfig, CPOTrainer
|
28 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
29 |
+
|
30 |
+
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
31 |
+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
32 |
+
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
33 |
+
|
34 |
+
training_args = CPOConfig(output_dir="Qwen2-0.5B-CPO", logging_steps=10)
|
35 |
+
trainer = CPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
|
36 |
+
trainer.train()
|
37 |
+
```
|
38 |
+
|
39 |
+
Execute the script using the following command:
|
40 |
+
|
41 |
+
```bash
|
42 |
+
accelerate launch train_cpo.py
|
43 |
+
```
|
44 |
+
|
45 |
+
## Expected dataset type
|
46 |
+
|
47 |
+
CPO requires a [preference dataset](dataset_formats#preference). The [`CPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
|
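For illustration (the exact rows depend on your dataset), a single preference example in each format looks like this:

```python
# Standard (plain text) preference example
preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}

# Conversational preference example; the trainer applies the chat template automatically
conversational_example = {
    "prompt": [{"role": "user", "content": "What color is the sky?"}],
    "chosen": [{"role": "assistant", "content": "It is blue."}],
    "rejected": [{"role": "assistant", "content": "It is green."}],
}
```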
48 |
+
|
49 |
+
## Example script
|
50 |
+
|
51 |
+
We provide an example script to train a model using the CPO method. The script is available in [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py)
|
52 |
+
|
53 |
+
To test the CPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command:
|
54 |
+
|
55 |
+
```bash
|
56 |
+
accelerate launch examples/scripts/cpo.py \
|
57 |
+
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
|
58 |
+
--dataset_name trl-lib/ultrafeedback_binarized \
|
59 |
+
--num_train_epochs 1 \
|
60 |
+
--logging_steps 25 \
|
61 |
+
--output_dir Qwen2-0.5B-CPO
|
62 |
+
```
|
63 |
+
|
64 |
+
## Logged metrics
|
65 |
+
|
66 |
+
While training and evaluating we record the following reward metrics:
|
67 |
+
|
68 |
+
* `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
|
69 |
+
* `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
|
70 |
+
* `rewards/accuracies`: mean of how often the chosen rewards are greater than the corresponding rejected rewards
|
71 |
+
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
|
72 |
+
* `nll_loss`: the mean negative log likelihood loss of the policy model for the chosen responses
|
73 |
+
|
74 |
+
## CPO variants
|
75 |
+
|
76 |
+
### Simple Preference Optimization (SimPO)
|
77 |
+
|
78 |
+
The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use behavior cloning (BC) regularization. To use this loss, set `loss_type="simpo"` and `cpo_alpha=0.0` in the [`CPOConfig`].
|
79 |
+
|
80 |
+
### CPO-SimPO
|
81 |
+
|
82 |
+
We also offer the combined use of CPO and SimPO, which enables more stable training and improved performance. Learn more details at [CPO-SimPO GitHub](https://github.com/fe1ixxu/CPO_SIMPO). To use this method, simply enable SimPO by setting `loss_type="simpo"` and a non-zero `cpo_alpha` in the [`CPOConfig`].
|
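A minimal sketch of both variants (hyperparameter values are placeholders; `simpo_gamma` is assumed to be the reward-margin parameter exposed by [`CPOConfig`]):

```python
from trl import CPOConfig

# SimPO: reward-margin loss without the BC (NLL) term
simpo_args = CPOConfig(
    output_dir="Qwen2-0.5B-SimPO",
    loss_type="simpo",
    cpo_alpha=0.0,    # disables the NLL regularizer
    simpo_gamma=0.5,  # reward margin used by the SimPO loss
)

# CPO-SimPO: the SimPO loss combined with a non-zero NLL weight
cpo_simpo_args = CPOConfig(
    output_dir="Qwen2-0.5B-CPO-SimPO",
    loss_type="simpo",
    cpo_alpha=0.05,   # non-zero weight re-enables the CPO NLL term
)
```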
83 |
+
|
84 |
+
## Loss functions
|
85 |
+
|
86 |
+
The CPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`CPOConfig`]. The following loss functions are supported:
|
87 |
+
|
88 |
+
| `loss_type=` | Description |
|
89 |
+
| -------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
90 |
+
| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model and in fact the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. |
|
91 |
+
| `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. |
|
92 |
+
| `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike DPO which is summed only). |
|
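For example (values are placeholders), the loss is selected through the [`CPOConfig`]:

```python
from trl import CPOConfig

# Use the hinge loss; for "hinge", beta is the reciprocal of the margin
training_args = CPOConfig(
    output_dir="Qwen2-0.5B-CPO",
    loss_type="hinge",  # one of "sigmoid" (default), "hinge", "ipo", "simpo"
    beta=0.1,
)
```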
93 |
+
|
94 |
+
### For Mixture of Experts Models: Enabling the auxiliary loss
|
95 |
+
|
96 |
+
MOEs are most efficient when the load is roughly evenly distributed across experts.
|
97 |
+
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
|
98 |
+
|
99 |
+
This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
|
100 |
+
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
|
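A minimal sketch, assuming a Mixtral-style checkpoint (the model name here is only an example): both options can be set when loading the model, since `from_pretrained` forwards unused keyword arguments to the model config.

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",  # example MoE checkpoint
    output_router_logits=True,   # add the load-balancing auxiliary loss to the total loss
    router_aux_loss_coef=0.001,  # weight of the auxiliary loss (default value shown)
)
```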
101 |
+
|
102 |
+
## CPOTrainer
|
103 |
+
|
104 |
+
[[autodoc]] CPOTrainer
|
105 |
+
|
106 |
+
## CPOConfig
|
107 |
+
|
108 |
+
[[autodoc]] CPOConfig
|
docs/source/customization.md
ADDED
@@ -0,0 +1,121 @@
1 |
+
# Training customization
|
2 |
+
|
3 |
+
TRL is designed with modularity in mind so that users can efficiently customize the training loop for their needs. Below are some examples of how you can apply and test different techniques. Note: Although these examples use the `DPOTrainer`, the customization applies to most (if not all) trainers.
|
4 |
+
|
5 |
+
|
6 |
+
|
7 |
+
## Use different optimizers and schedulers
|
8 |
+
|
9 |
+
By default, the `DPOTrainer` creates a `torch.optim.AdamW` optimizer. You can define a different optimizer and pass it to the `DPOTrainer` as follows:
|
10 |
+
|
11 |
+
```python
|
12 |
+
from datasets import load_dataset
|
13 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
14 |
+
from torch import optim
|
15 |
+
from trl import DPOConfig, DPOTrainer
|
16 |
+
|
17 |
+
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
18 |
+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
19 |
+
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
20 |
+
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
|
21 |
+
|
22 |
+
optimizer = optim.SGD(model.parameters(), lr=training_args.learning_rate)
|
23 |
+
|
24 |
+
trainer = DPOTrainer(
|
25 |
+
model=model,
|
26 |
+
args=training_args,
|
27 |
+
train_dataset=dataset,
|
28 |
+
tokenizer=tokenizer,
|
29 |
+
optimizers=(optimizer, None),
|
30 |
+
)
|
31 |
+
trainer.train()
|
32 |
+
```
|
33 |
+
|
34 |
+
### Add a learning rate scheduler
|
35 |
+
|
36 |
+
You can also customize the training by adding a learning rate scheduler.
|
37 |
+
|
38 |
+
```python
|
39 |
+
from datasets import load_dataset
|
40 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
41 |
+
from torch import optim
|
42 |
+
from trl import DPOConfig, DPOTrainer
|
43 |
+
|
44 |
+
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
45 |
+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
46 |
+
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
47 |
+
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
|
48 |
+
|
49 |
+
optimizer = optim.AdamW(model.parameters(), lr=training_args.learning_rate)
|
50 |
+
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
|
51 |
+
|
52 |
+
trainer = DPOTrainer(
|
53 |
+
model=model,
|
54 |
+
args=training_args,
|
55 |
+
train_dataset=dataset,
|
56 |
+
tokenizer=tokenizer,
|
57 |
+
optimizers=(optimizer, lr_scheduler),
|
58 |
+
)
|
59 |
+
trainer.train()
|
60 |
+
```
|
61 |
+
|
62 |
+
## Memory efficient fine-tuning by sharing layers
|
63 |
+
|
64 |
+
Another way to make fine-tuning more memory-efficient is to share layers between the reference model and the model you want to train.
|
65 |
+
|
66 |
+
```python
|
67 |
+
from datasets import load_dataset
|
68 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
69 |
+
from trl import create_reference_model, DPOConfig, DPOTrainer
|
70 |
+
|
71 |
+
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
72 |
+
ref_model = create_reference_model(model, num_shared_layers=6)
|
73 |
+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
74 |
+
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train[:1%]")
|
75 |
+
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
|
76 |
+
|
77 |
+
trainer = DPOTrainer(
|
78 |
+
model=model,
|
79 |
+
ref_model=ref_model,
|
80 |
+
args=training_args,
|
81 |
+
train_dataset=dataset,
|
82 |
+
tokenizer=tokenizer,
|
83 |
+
)
|
84 |
+
trainer.train()
|
85 |
+
```
|
86 |
+
|
87 |
+
## Pass 8-bit reference models
|
88 |
+
|
89 |
+
Since `trl` supports all keyword arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage 8-bit loading from `transformers` (via `BitsAndBytesConfig`) for more memory-efficient fine-tuning.
|
90 |
+
|
91 |
+
Read more about 8-bit model loading in `transformers` [here](https://huggingface.co/docs/transformers/en/peft#load-in-8bit-or-4bit).
|
92 |
+
|
93 |
+
```python
|
94 |
+
from datasets import load_dataset
|
95 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
96 |
+
from trl import DPOConfig, DPOTrainer
|
97 |
+
|
98 |
+
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
99 |
+
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
|
100 |
+
ref_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", quantization_config=quantization_config)
|
101 |
+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
102 |
+
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
103 |
+
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
|
104 |
+
|
105 |
+
trainer = DPOTrainer(
|
106 |
+
model=model,
|
107 |
+
ref_model=ref_model,
|
108 |
+
args=training_args,
|
109 |
+
train_dataset=dataset,
|
110 |
+
tokenizer=tokenizer,
|
111 |
+
)
|
112 |
+
trainer.train()
|
113 |
+
```
|
114 |
+
|
115 |
+
## Use the accelerator cache optimizer
|
116 |
+
|
117 |
+
When training large models, you can reduce memory pressure by iteratively clearing the accelerator cache. To do so, simply pass `optimize_device_cache=True` to `DPOConfig`:
|
118 |
+
|
119 |
+
```python
|
120 |
+
training_args = DPOConfig(..., optimize_device_cache=True)
|
121 |
+
```
|
docs/source/data_utils.md
ADDED
@@ -0,0 +1,41 @@
1 |
+
# Data Utilities
|
2 |
+
|
3 |
+
## is_conversational
|
4 |
+
|
5 |
+
[[autodoc]] is_conversational
|
6 |
+
|
7 |
+
## apply_chat_template
|
8 |
+
|
9 |
+
[[autodoc]] apply_chat_template
|
10 |
+
|
11 |
+
## maybe_apply_chat_template
|
12 |
+
|
13 |
+
[[autodoc]] maybe_apply_chat_template
|
14 |
+
|
15 |
+
## maybe_convert_to_chatml
|
16 |
+
|
17 |
+
[[autodoc]] maybe_convert_to_chatml
|
18 |
+
|
19 |
+
## extract_prompt
|
20 |
+
|
21 |
+
[[autodoc]] extract_prompt
|
22 |
+
|
23 |
+
## maybe_extract_prompt
|
24 |
+
|
25 |
+
[[autodoc]] maybe_extract_prompt
|
26 |
+
|
27 |
+
## unpair_preference_dataset
|
28 |
+
|
29 |
+
[[autodoc]] unpair_preference_dataset
|
30 |
+
|
31 |
+
## maybe_unpair_preference_dataset
|
32 |
+
|
33 |
+
[[autodoc]] maybe_unpair_preference_dataset
|
34 |
+
|
35 |
+
## pack_dataset
|
36 |
+
|
37 |
+
[[autodoc]] pack_dataset
|
38 |
+
|
39 |
+
## truncate_dataset
|
40 |
+
|
41 |
+
[[autodoc]] truncate_dataset
|
docs/source/dataset_formats.md
ADDED
@@ -0,0 +1,938 @@
1 |
+
# Dataset formats and types
|
2 |
+
|
3 |
+
This guide provides an overview of the dataset formats and types supported by each trainer in TRL.
|
4 |
+
|
5 |
+
## Overview of the dataset formats and types
|
6 |
+
|
7 |
+
- The *format* of a dataset refers to how the data is structured, typically categorized as either *standard* or *conversational*.
|
8 |
+
- The *type* is associated with the specific task the dataset is designed for, such as *prompt-only* or *preference*. Each type is characterized by its columns, which vary according to the task, as shown in the table.
|
9 |
+
|
10 |
+
<table>
|
11 |
+
<tr>
|
12 |
+
<th>Type \ Format</th>
|
13 |
+
<th>Standard</th>
|
14 |
+
<th>Conversational</th>
|
15 |
+
</tr>
|
16 |
+
<tr>
|
17 |
+
<td>Language modeling</td>
|
18 |
+
<td>
|
19 |
+
<pre><code>{"text": "The sky is blue."}</code></pre>
|
20 |
+
</td>
|
21 |
+
<td>
|
22 |
+
<pre><code>{"messages": [{"role": "user", "content": "What color is the sky?"},
|
23 |
+
{"role": "assistant", "content": "It is blue."}]}</code></pre>
|
24 |
+
</td>
|
25 |
+
</tr>
|
26 |
+
<tr>
|
27 |
+
<td>Prompt-only</td>
|
28 |
+
<td>
|
29 |
+
<pre><code>{"prompt": "The sky is"}</code></pre>
|
30 |
+
</td>
|
31 |
+
<td>
|
32 |
+
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}]}</code></pre>
|
33 |
+
</td>
|
34 |
+
</tr>
|
35 |
+
<tr>
|
36 |
+
<td>Prompt-completion</td>
|
37 |
+
<td>
|
38 |
+
<pre><code>{"prompt": "The sky is",
|
39 |
+
"completion": " blue."}</code></pre>
|
40 |
+
</td>
|
41 |
+
<td>
|
42 |
+
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
43 |
+
"completion": [{"role": "assistant", "content": "It is blue."}]}</code></pre>
|
44 |
+
</td>
|
45 |
+
</tr>
|
46 |
+
</tr>
|
47 |
+
<tr>
|
48 |
+
<td>Preference</td>
|
49 |
+
<td>
|
50 |
+
<pre><code>{"prompt": "The sky is",
|
51 |
+
"chosen": " blue.",
|
52 |
+
"rejected": " green."}</code></pre>
|
53 |
+
or, with implicit prompt:
|
54 |
+
<pre><code>{"chosen": "The sky is blue.",
|
55 |
+
"rejected": "The sky is green."}</code></pre>
|
56 |
+
</td>
|
57 |
+
<td>
|
58 |
+
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
59 |
+
"chosen": [{"role": "assistant", "content": "It is blue."}],
|
60 |
+
"rejected": [{"role": "assistant", "content": "It is green."}]}</code></pre>
|
61 |
+
or, with implicit prompt:
|
62 |
+
<pre><code>{"chosen": [{"role": "user", "content": "What color is the sky?"},
|
63 |
+
{"role": "assistant", "content": "It is blue."}],
|
64 |
+
"rejected": [{"role": "user", "content": "What color is the sky?"},
|
65 |
+
{"role": "assistant", "content": "It is green."}]}</code></pre>
|
66 |
+
</td>
|
67 |
+
</tr>
|
68 |
+
<tr>
    <td>Unpaired preference</td>
|
69 |
+
<td>
|
70 |
+
<pre><code>{"prompt": "The sky is",
|
71 |
+
"completion": " blue.",
|
72 |
+
"label": True}</code></pre>
|
73 |
+
</td>
|
74 |
+
<td>
|
75 |
+
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
76 |
+
"completion": [{"role": "assistant", "content": "It is green."}],
|
77 |
+
"label": False}</code></pre>
|
78 |
+
</td>
|
79 |
+
</tr>
|
80 |
+
</tr>
|
81 |
+
<tr>
    <td>Stepwise supervision</td>
|
82 |
+
<td>
|
83 |
+
<pre><code>{"prompt": "Which number is larger, 9.8 or 9.11?",
|
84 |
+
"completions": ["The fractional part of 9.8 is 0.8.",
|
85 |
+
"The fractional part of 9.11 is 0.11.",
|
86 |
+
"0.11 is greater than 0.8.",
|
87 |
+
"Hence, 9.11 > 9.8."],
|
88 |
+
"labels": [True, True, False, False]}</code></pre>
|
89 |
+
</td>
|
90 |
+
<td></td>
|
91 |
+
</tr>
|
92 |
+
</table>
|
93 |
+
|
94 |
+
### Formats
|
95 |
+
|
96 |
+
#### Standard
|
97 |
+
|
98 |
+
The standard dataset format typically consists of plain text strings. The columns in the dataset vary depending on the task. This is the format expected by TRL trainers. Below are examples of standard dataset formats for different tasks:
|
99 |
+
|
100 |
+
```python
|
101 |
+
# Language modeling
|
102 |
+
language_modeling_example = {"text": "The sky is blue."}
|
103 |
+
# Preference
|
104 |
+
preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}
|
105 |
+
# Unpaired preference
|
106 |
+
unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "label": True}
|
107 |
+
```
|
108 |
+
|
109 |
+
#### Conversational
|
110 |
+
|
111 |
+
Conversational datasets are used for tasks involving dialogues or chat interactions between users and assistants. Unlike standard dataset formats, these contain sequences of messages where each message has a `role` (e.g., `"user"` or `"assistant"`) and `content` (the message text).
|
112 |
+
|
113 |
+
```python
|
114 |
+
messages = [
|
115 |
+
{"role": "user", "content": "Hello, how are you?"},
|
116 |
+
{"role": "assistant", "content": "I'm doing great. How can I help you today?"},
|
117 |
+
{"role": "user", "content": "I'd like to show off how chat templating works!"},
|
118 |
+
]
|
119 |
+
```
|
120 |
+
|
121 |
+
Just like standard datasets, the columns in conversational datasets vary depending on the task. Below are examples of conversational dataset formats for different tasks:
|
122 |
+
|
123 |
+
```python
|
124 |
+
# Prompt-completion
|
125 |
+
prompt_completion_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
126 |
+
"completion": [{"role": "assistant", "content": "It is blue."}]}
|
127 |
+
# Preference
|
128 |
+
preference_example = {
|
129 |
+
"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
130 |
+
"chosen": [{"role": "assistant", "content": "It is blue."}],
|
131 |
+
"rejected": [{"role": "assistant", "content": "It is green."}],
|
132 |
+
}
|
133 |
+
```
|
134 |
+
|
135 |
+
Conversational datasets are useful for training chat models, but must be converted into a standard format before being used with TRL trainers. This is typically done using chat templates specific to the model being used. For more information, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section.
|
136 |
+
|
137 |
+
### Types
|
138 |
+
|
139 |
+
#### Language modeling
|
140 |
+
|
141 |
+
A language modeling dataset consists of a column `"text"` (or `"messages"` for conversational datasets) containing a full sequence of text.
|
142 |
+
|
143 |
+
```python
|
144 |
+
# Standard format
|
145 |
+
language_modeling_example = {"text": "The sky is blue."}
|
146 |
+
# Conversational format
|
147 |
+
language_modeling_example = {"messages": [
|
148 |
+
{"role": "user", "content": "What color is the sky?"},
|
149 |
+
{"role": "assistant", "content": "It is blue."}
|
150 |
+
]}
|
151 |
+
```
|
152 |
+
|
153 |
+
#### Prompt-only
|
154 |
+
|
155 |
+
In a prompt-only dataset, only the initial prompt (the question or partial sentence) is provided under the key `"prompt"`. The training typically involves generating the completion based on this prompt, where the model learns to continue or complete the given input.
|
156 |
+
|
157 |
+
```python
|
158 |
+
# Standard format
|
159 |
+
prompt_only_example = {"prompt": "The sky is"}
|
160 |
+
# Conversational format
|
161 |
+
prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
|
162 |
+
```
|
163 |
+
|
164 |
+
For examples of prompt-only datasets, refer to the [Prompt-only datasets collection](https://huggingface.co/collections/trl-lib/prompt-only-datasets-677ea25245d20252cea00368).
|
165 |
+
|
166 |
+
<Tip>
|
167 |
+
|
168 |
+
While both the prompt-only and language modeling types are similar, they differ in how the input is handled. In the prompt-only type, the prompt represents a partial input that expects the model to complete or continue, while in the language modeling type, the input is treated as a complete sentence or sequence. These two types are processed differently by TRL. Below is an example showing the difference in the output of the `apply_chat_template` function for each type:
|
169 |
+
|
170 |
+
```python
|
171 |
+
from transformers import AutoTokenizer
|
172 |
+
from trl import apply_chat_template
|
173 |
+
|
174 |
+
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
|
175 |
+
|
176 |
+
# Example for prompt-only type
|
177 |
+
prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
|
178 |
+
apply_chat_template(prompt_only_example, tokenizer)
|
179 |
+
# Output: {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n'}
|
180 |
+
|
181 |
+
# Example for language modeling type
|
182 |
+
lm_example = {"messages": [{"role": "user", "content": "What color is the sky?"}]}
|
183 |
+
apply_chat_template(lm_example, tokenizer)
|
184 |
+
# Output: {'text': '<|user|>\nWhat color is the sky?<|end|>\n<|endoftext|>'}
|
185 |
+
```
|
186 |
+
|
187 |
+
- The prompt-only output includes a `'<|assistant|>\n'`, indicating the beginning of the assistant’s turn and expecting the model to generate a completion.
|
188 |
+
- In contrast, the language modeling output treats the input as a complete sequence and terminates it with `'<|endoftext|>'`, signaling the end of the text and not expecting any additional content.
|
189 |
+
|
190 |
+
</Tip>
|
191 |
+
|
192 |
+
#### Prompt-completion
|
193 |
+
|
194 |
+
A prompt-completion dataset includes a `"prompt"` and a `"completion"`.
|
195 |
+
|
196 |
+
```python
|
197 |
+
# Standard format
|
198 |
+
prompt_completion_example = {"prompt": "The sky is", "completion": " blue."}
|
199 |
+
# Conversational format
|
200 |
+
prompt_completion_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
201 |
+
"completion": [{"role": "assistant", "content": "It is blue."}]}
|
202 |
+
```
|
203 |
+
|
204 |
+
For examples of prompt-completion datasets, refer to the [Prompt-completion datasets collection](https://huggingface.co/collections/trl-lib/prompt-completion-datasets-677ea2bb20bbb6bdccada216).
|
205 |
+
|
206 |
+
#### Preference
|
207 |
+
|
208 |
+
A preference dataset is used for tasks where the model is trained to choose between two or more possible completions to the same prompt. This dataset includes a `"prompt"`, a `"chosen"` completion, and a `"rejected"` completion. The model is trained to select the `"chosen"` response over the `"rejected"` response.
|
209 |
+
Some datasets may not include the `"prompt"` column, in which case the prompt is implicit and directly included in the `"chosen"` and `"rejected"` completions. We recommend using explicit prompts whenever possible.
|
210 |
+
|
211 |
+
```python
|
212 |
+
# Standard format
|
213 |
+
## Explicit prompt (recommended)
|
214 |
+
preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}
|
215 |
+
# Implicit prompt
|
216 |
+
preference_example = {"chosen": "The sky is blue.", "rejected": "The sky is green."}
|
217 |
+
|
218 |
+
# Conversational format
|
219 |
+
## Explicit prompt (recommended)
|
220 |
+
preference_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
221 |
+
"chosen": [{"role": "assistant", "content": "It is blue."}],
|
222 |
+
"rejected": [{"role": "assistant", "content": "It is green."}]}
|
223 |
+
## Implicit prompt
|
224 |
+
preference_example = {"chosen": [{"role": "user", "content": "What color is the sky?"},
|
225 |
+
{"role": "assistant", "content": "It is blue."}],
|
226 |
+
"rejected": [{"role": "user", "content": "What color is the sky?"},
|
227 |
+
{"role": "assistant", "content": "It is green."}]}
|
228 |
+
```
|
229 |
+
|
230 |
+
For examples of preference datasets, refer to the [Preference datasets collection](https://huggingface.co/collections/trl-lib/preference-datasets-677e99b581018fcad9abd82c).
|
231 |
+
|
232 |
+
Some preference datasets can be found with [the tag `dpo` on Hugging Face Hub](https://huggingface.co/datasets?other=dpo). You can also explore the [librarian-bots' DPO Collections](https://huggingface.co/collections/librarian-bots/direct-preference-optimization-datasets-66964b12835f46289b6ef2fc) to identify preference datasets.
|
233 |
+
|
234 |
+
#### Unpaired preference
|
235 |
+
|
236 |
+
An unpaired preference dataset is similar to a preference dataset but instead of having `"chosen"` and `"rejected"` completions for the same prompt, it includes a single `"completion"` and a `"label"` indicating whether the completion is preferred or not.
|
237 |
+
|
238 |
+
```python
|
239 |
+
# Standard format
|
240 |
+
unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "label": True}
|
241 |
+
# Conversational format
|
242 |
+
unpaired_preference_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
243 |
+
"completion": [{"role": "assistant", "content": "It is blue."}],
|
244 |
+
"label": True}
|
245 |
+
```
|
246 |
+
|
247 |
+
For examples of unpaired preference datasets, refer to the [Unpaired preference datasets collection](https://huggingface.co/collections/trl-lib/unpaired-preference-datasets-677ea22bf5f528c125b0bcdf).
|
248 |
+
|
249 |
+
#### Stepwise supervision
|
250 |
+
|
251 |
+
A stepwise (or process) supervision dataset is similar to an [unpaired preference](#unpaired-preference) dataset but includes multiple steps of completions, each with its own label. This structure is useful for tasks that need detailed, step-by-step labeling, such as reasoning tasks. By evaluating each step separately and providing targeted labels, this approach helps identify precisely where the reasoning is correct and where errors occur, allowing for targeted feedback on each part of the reasoning process.
|
252 |
+
|
253 |
+
```python
|
254 |
+
stepwise_example = {
|
255 |
+
"prompt": "Which number is larger, 9.8 or 9.11?",
|
256 |
+
"completions": ["The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.", "Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8."],
|
257 |
+
"labels": [True, False]
|
258 |
+
}
|
259 |
+
```
|
260 |
+
|
261 |
+
For examples of stepwise supervision datasets, refer to the [Stepwise supervision datasets collection](https://huggingface.co/collections/trl-lib/stepwise-supervision-datasets-677ea27fd4c5941beed7a96e).
|
262 |
+
|
263 |
+
## Which dataset type to use?
|
264 |
+
|
265 |
+
Choosing the right dataset type depends on the task you are working on and the specific requirements of the TRL trainer you are using. Below is a brief overview of the dataset types supported by each TRL trainer.
|
266 |
+
|
267 |
+
| Trainer | Expected dataset type |
|
268 |
+
| ----------------------- | ------------------------------------------------------------------------------------------------------ |
|
269 |
+
| [`BCOTrainer`] | [Unpaired preference](#unpaired-preference) |
|
270 |
+
| [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
|
271 |
+
| [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
|
272 |
+
| [`GKDTrainer`] | [Prompt-completion](#prompt-completion) |
|
273 |
+
| [`GRPOTrainer`] | [Prompt-only](#prompt-only) |
|
274 |
+
| [`IterativeSFTTrainer`] | [Unpaired preference](#unpaired-preference) |
|
275 |
+
| [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
|
276 |
+
| [`NashMDTrainer`] | [Prompt-only](#prompt-only) |
|
277 |
+
| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
|
278 |
+
| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
|
279 |
+
| [`PPOTrainer`] | Tokenized language modeling |
|
280 |
+
| [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
|
281 |
+
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
|
282 |
+
| [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) |
|
283 |
+
| [`XPOTrainer`] | [Prompt-only](#prompt-only) |
|
284 |
+
|
285 |
+
<Tip>
|
286 |
+
|
287 |
+
TRL trainers only support standard dataset formats, [for now](https://github.com/huggingface/trl/issues/2071). If you have a conversational dataset, you must first convert it into a standard format.
|
288 |
+
For more information on how to work with conversational datasets, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section.
|
289 |
+
|
290 |
+
</Tip>
|
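A quick way to see which type a dataset is (a small illustration; the dataset name is just an example) is to inspect its column names:

```python
from datasets import load_dataset

dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
print(dataset.column_names)
# A "chosen"/"rejected" pair (with or without "prompt") indicates a preference dataset
```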
291 |
+
|
292 |
+
## Working with conversational datasets in TRL
|
293 |
+
|
294 |
+
Conversational datasets are increasingly common, especially for training chat models. However, some TRL trainers don't support conversational datasets in their raw format. (For more information, see [issue #2071](https://github.com/huggingface/trl/issues/2071).) These datasets must first be converted into a standard format.
|
295 |
+
Fortunately, TRL offers tools to easily handle this conversion, which are detailed below.
|
296 |
+
|
297 |
+
### Converting a conversational dataset into a standard dataset
|
298 |
+
|
299 |
+
To convert a conversational dataset into a standard dataset, you need to _apply a chat template_ to the dataset. A chat template is a predefined structure that typically includes placeholders for user and assistant messages. This template is provided by the tokenizer of the model you use.
|
300 |
+
|
301 |
+
For detailed instructions on using chat templating, refer to the [Chat templating section in the `transformers` documentation](https://huggingface.co/docs/transformers/en/chat_templating).
|
302 |
+
|
303 |
+
In TRL, the method you apply to convert the dataset will vary depending on the task. Fortunately, TRL provides a helper function called [`apply_chat_template`] to simplify this process. Here's an example of how to use it:
|
304 |
+
|
305 |
+
```python
|
306 |
+
from transformers import AutoTokenizer
|
307 |
+
from trl import apply_chat_template
|
308 |
+
|
309 |
+
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
|
310 |
+
|
311 |
+
example = {
|
312 |
+
"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
313 |
+
"completion": [{"role": "assistant", "content": "It is blue."}]
|
314 |
+
}
|
315 |
+
|
316 |
+
apply_chat_template(example, tokenizer)
|
317 |
+
# Output:
|
318 |
+
# {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n', 'completion': 'It is blue.<|end|>\n<|endoftext|>'}
|
319 |
+
```
|
320 |
+
|
321 |
+
Alternatively, you can use the [`~datasets.Dataset.map`] method to apply the template across an entire dataset:
|
322 |
+
|
323 |
+
```python
|
324 |
+
from datasets import Dataset
|
325 |
+
from trl import apply_chat_template
|
326 |
+
|
327 |
+
dataset_dict = {
|
328 |
+
"prompt": [[{"role": "user", "content": "What color is the sky?"}],
|
329 |
+
[{"role": "user", "content": "Where is the sun?"}]],
|
330 |
+
"completion": [[{"role": "assistant", "content": "It is blue."}],
|
331 |
+
[{"role": "assistant", "content": "In the sky."}]]
|
332 |
+
}
|
333 |
+
|
334 |
+
dataset = Dataset.from_dict(dataset_dict)
|
335 |
+
dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
|
336 |
+
# Output:
|
337 |
+
# {'prompt': ['<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n',
|
338 |
+
# '<|user|>\nWhere is the sun?<|end|>\n<|assistant|>\n'],
|
339 |
+
# 'completion': ['It is blue.<|end|>\n<|endoftext|>', 'In the sky.<|end|>\n<|endoftext|>']}
|
340 |
+
```
|
341 |
+
|
342 |
+
<Tip warning={true}>
|
343 |
+
|
344 |
+
We recommend using the [`apply_chat_template`] function instead of calling `tokenizer.apply_chat_template` directly. Handling chat templates for non-language modeling datasets can be tricky and may result in errors, such as mistakenly placing a system prompt in the middle of a conversation.
|
345 |
+
For additional examples, see [#1930 (comment)](https://github.com/huggingface/trl/pull/1930#issuecomment-2292908614). The [`apply_chat_template`] function is designed to handle these intricacies and ensure the correct application of chat templates for various tasks.
|
346 |
+
|
347 |
+
</Tip>
|
348 |
+
|
349 |
+
<Tip warning={true}>
|
350 |
+
|
351 |
+
It's important to note that chat templates are model-specific. For example, if you use the chat template from [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) with the above example, you get a different output:
|
352 |
+
|
353 |
+
```python
|
354 |
+
apply_chat_template(example, AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct"))
|
355 |
+
# Output:
|
356 |
+
# {'prompt': '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n',
|
357 |
+
# 'completion': 'It is blue.<|im_end|>\n'}
|
358 |
+
```
|
359 |
+
|
360 |
+
Always use the chat template associated with the model you're working with. Using the wrong template can lead to inaccurate or unexpected results.
|
361 |
+
|
362 |
+
</Tip>
|
363 |
+
|
364 |
+
## Using any dataset with TRL: preprocessing and conversion
|
365 |
+
|
366 |
+
Many datasets come in formats tailored to specific tasks, which might not be directly compatible with TRL. To use such datasets with TRL, you may need to preprocess and convert them into the required format.
|
367 |
+
|
368 |
+
To make this easier, we provide a set of [example scripts](https://github.com/huggingface/trl/tree/main/examples/datasets) that cover common dataset conversions.
|
369 |
+
|
370 |
+
### Example: UltraFeedback dataset
|
371 |
+
|
372 |
+
Let’s take the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback) as an example. Here's a preview of the dataset:
|
373 |
+
|
374 |
+
<iframe
|
375 |
+
src="https://huggingface.co/datasets/openbmb/UltraFeedback/embed/viewer/default/train"
|
376 |
+
frameborder="0"
|
377 |
+
width="100%"
|
378 |
+
height="560px"
|
379 |
+
></iframe>
|
380 |
+
|
381 |
+
As shown above, the dataset format does not match the expected structure. It’s not in a conversational format, the column names differ, and the results pertain to different models (e.g., Bard, GPT-4) and aspects (e.g., "helpfulness", "honesty").
|
382 |
+
|
383 |
+
By using the provided conversion script [`examples/datasets/ultrafeedback.py`](https://github.com/huggingface/trl/tree/main/examples/datasets/ultrafeedback.py), you can transform this dataset into an unpaired preference type, and push it to the Hub:
|
384 |
+
|
385 |
+
```sh
|
386 |
+
python examples/datasets/ultrafeedback.py --push_to_hub --repo_id trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness
|
387 |
+
```
|
388 |
+
|
389 |
+
Once converted, the dataset will look like this:
|
390 |
+
|
391 |
+
<iframe
|
392 |
+
src="https://huggingface.co/datasets/trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness/embed/viewer/default/train?row=0"
|
393 |
+
frameborder="0"
|
394 |
+
width="100%"
|
395 |
+
height="560px"
|
396 |
+
></iframe>
|
397 |
+
|
398 |
+
Now, you can use this dataset with TRL!
|
399 |
+
|
400 |
+
By adapting the provided scripts or creating your own, you can convert any dataset into a format compatible with TRL.
|
401 |
+
|
402 |
+
## Utilities for converting dataset types
|
403 |
+
|
404 |
+
This section provides example code to help you convert between different dataset types. While some conversions can be performed after applying the chat template (i.e., in the standard format), we recommend performing the conversion before applying the chat template to ensure it works consistently.
|
405 |
+
|
406 |
+
For simplicity, some of the examples below do not follow this recommendation and use the standard format. However, the conversions can be applied directly to the conversational format without modification.
|
407 |
+
|
408 |
+
| From \ To | Language modeling | Prompt-completion | Prompt-only | Preference with implicit prompt | Preference | Unpaired preference | Stepwise supervision |
|
409 |
+
| ------------------------------- | ----------------------------------------------------------------------- | ----------------------------------------------------------------------- | ----------------------------------------------------------------- | --------------------------------------------------------- | --------------------------------------------------------- | ------------------------------------------------------------------------- | -------------------- |
|
410 |
+
| Language modeling | N/A | N/A | N/A | N/A | N/A | N/A | N/A |
|
411 |
+
| Prompt-completion | [🔗](#from-prompt-completion-to-language-modeling-dataset) | N/A | [🔗](#from-prompt-completion-to-prompt-only-dataset) | N/A | N/A | N/A | N/A |
|
412 |
+
| Prompt-only | N/A | N/A | N/A | N/A | N/A | N/A | N/A |
|
413 |
+
| Preference with implicit prompt | [🔗](#from-preference-with-implicit-prompt-to-language-modeling-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-completion-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-only-dataset) | N/A | [🔗](#from-implicit-to-explicit-prompt-preference-dataset) | [🔗](#from-preference-with-implicit-prompt-to-unpaired-preference-dataset) | N/A |
|
414 |
+
| Preference | [🔗](#from-preference-to-language-modeling-dataset) | [🔗](#from-preference-to-prompt-completion-dataset) | [🔗](#from-preference-to-prompt-only-dataset) | [🔗](#from-explicit-to-implicit-prompt-preference-dataset) | N/A | [🔗](#from-preference-to-unpaired-preference-dataset) | N/A |
|
415 |
+
| Unpaired preference | [🔗](#from-unpaired-preference-to-language-modeling-dataset) | [🔗](#from-unpaired-preference-to-prompt-completion-dataset) | [🔗](#from-unpaired-preference-to-prompt-only-dataset) | N/A | N/A | N/A | N/A |
|
416 |
+
| Stepwise supervision | [🔗](#from-stepwise-supervision-to-language-modeling-dataset) | [🔗](#from-stepwise-supervision-to-prompt-completion-dataset) | [🔗](#from-stepwise-supervision-to-prompt-only-dataset) | N/A | N/A | [🔗](#from-stepwise-supervision-to-unpaired-preference-dataset) | N/A |
|
417 |
+
|
418 |
+
### From prompt-completion to language modeling dataset
|
419 |
+
|
420 |
+
To convert a prompt-completion dataset into a language modeling dataset, concatenate the prompt and the completion.
|
421 |
+
|
422 |
+
```python
|
423 |
+
from datasets import Dataset
|
424 |
+
|
425 |
+
dataset = Dataset.from_dict({
|
426 |
+
"prompt": ["The sky is", "The sun is"],
|
427 |
+
"completion": [" blue.", " in the sky."],
|
428 |
+
})
|
429 |
+
|
430 |
+
def concat_prompt_completion(example):
|
431 |
+
return {"text": example["prompt"] + example["completion"]}
|
432 |
+
|
433 |
+
dataset = dataset.map(concat_prompt_completion, remove_columns=["prompt", "completion"])
|
434 |
+
```
|
435 |
+
|
436 |
+
```python
|
437 |
+
>>> dataset[0]
|
438 |
+
{'text': 'The sky is blue.'}
|
439 |
+
```
|
440 |
+
|
441 |
+
### From prompt-completion to prompt-only dataset
|
442 |
+
|
443 |
+
To convert a prompt-completion dataset into a prompt-only dataset, remove the completion.
|
444 |
+
|
445 |
+
```python
|
446 |
+
from datasets import Dataset
|
447 |
+
|
448 |
+
dataset = Dataset.from_dict({
|
449 |
+
"prompt": ["The sky is", "The sun is"],
|
450 |
+
"completion": [" blue.", " in the sky."],
|
451 |
+
})
|
452 |
+
|
453 |
+
dataset = dataset.remove_columns("completion")
|
454 |
+
```
|
455 |
+
|
456 |
+
```python
|
457 |
+
>>> dataset[0]
|
458 |
+
{'prompt': 'The sky is'}
|
459 |
+
```
|
460 |
+
|
461 |
+
### From preference with implicit prompt to language modeling dataset
|
462 |
+
|
463 |
+
To convert a preference dataset with implicit prompt into a language modeling dataset, remove the rejected column and rename the `"chosen"` column to `"text"`.
|
464 |
+
|
465 |
+
```python
|
466 |
+
from datasets import Dataset
|
467 |
+
|
468 |
+
dataset = Dataset.from_dict({
|
469 |
+
"chosen": ["The sky is blue.", "The sun is in the sky."],
|
470 |
+
"rejected": ["The sky is green.", "The sun is in the sea."],
|
471 |
+
})
|
472 |
+
|
473 |
+
dataset = dataset.rename_column("chosen", "text").remove_columns("rejected")
|
474 |
+
```
|
475 |
+
|
476 |
+
```python
|
477 |
+
>>> dataset[0]
|
478 |
+
{'text': 'The sky is blue.'}
|
479 |
+
```
|
480 |
+
|
481 |
+
### From preference with implicit prompt to prompt-completion dataset
|
482 |
+
|
483 |
+
To convert a preference dataset with implicit prompt into a prompt-completion dataset, extract the prompt with [`extract_prompt`], remove the rejected, and rename the column `"chosen"` to `"completion"`.
|
484 |
+
|
485 |
+
```python
|
486 |
+
from datasets import Dataset
|
487 |
+
from trl import extract_prompt
|
488 |
+
|
489 |
+
dataset = Dataset.from_dict({
|
490 |
+
"chosen": [
|
491 |
+
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
|
492 |
+
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
|
493 |
+
],
|
494 |
+
"rejected": [
|
495 |
+
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
|
496 |
+
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
|
497 |
+
],
|
498 |
+
})
|
499 |
+
dataset = dataset.map(extract_prompt).remove_columns("rejected").rename_column("chosen", "completion")
|
500 |
+
```
|
501 |
+
|
502 |
+
```python
|
503 |
+
>>> dataset[0]
|
504 |
+
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}], 'completion': [{'role': 'assistant', 'content': 'It is blue.'}]}
|
505 |
+
```
|
506 |
+
|
507 |
+
### From preference with implicit prompt to prompt-only dataset
|
508 |
+
|
509 |
+
To convert a preference dataset with implicit prompt into a prompt-only dataset, extract the prompt with [`extract_prompt`], and remove the rejected and the chosen.
|
510 |
+
|
511 |
+
```python
|
512 |
+
from datasets import Dataset
|
513 |
+
from trl import extract_prompt
|
514 |
+
|
515 |
+
dataset = Dataset.from_dict({
|
516 |
+
"chosen": [
|
517 |
+
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
|
518 |
+
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
|
519 |
+
],
|
520 |
+
"rejected": [
|
521 |
+
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
|
522 |
+
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
|
523 |
+
],
|
524 |
+
})
|
525 |
+
dataset = dataset.map(extract_prompt).remove_columns(["chosen", "rejected"])
|
526 |
+
```
|
527 |
+
|
528 |
+
```python
|
529 |
+
>>> dataset[0]
|
530 |
+
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}]}
|
531 |
+
```
|
532 |
+
|
533 |
+
### From implicit to explicit prompt preference dataset
|
534 |
+
|
535 |
+
To convert a preference dataset with implicit prompt into a preference dataset with explicit prompt, extract the prompt with [`extract_prompt`].
|
536 |
+
|
537 |
+
```python
|
538 |
+
from datasets import Dataset
|
539 |
+
from trl import extract_prompt
|
540 |
+
|
541 |
+
dataset = Dataset.from_dict({
|
542 |
+
"chosen": [
|
543 |
+
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
|
544 |
+
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
|
545 |
+
],
|
546 |
+
"rejected": [
|
547 |
+
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
|
548 |
+
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
|
549 |
+
],
|
550 |
+
})
|
551 |
+
|
552 |
+
dataset = dataset.map(extract_prompt)
|
553 |
+
```
|
554 |
+
|
555 |
+
```python
|
556 |
+
>>> dataset[0]
|
557 |
+
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}],
|
558 |
+
'chosen': [{'role': 'assistant', 'content': 'It is blue.'}],
|
559 |
+
'rejected': [{'role': 'assistant', 'content': 'It is green.'}]}
|
560 |
+
```
|
561 |
+
|
562 |
+
### From preference with implicit prompt to unpaired preference dataset
|
563 |
+
|
564 |
+
To convert a preference dataset with implicit prompt into an unpaired preference dataset, extract the prompt with [`extract_prompt`], and unpair the dataset with [`unpair_preference_dataset`].
|
565 |
+
|
566 |
+
```python
|
567 |
+
from datasets import Dataset
|
568 |
+
from trl import extract_prompt, unpair_preference_dataset
|
569 |
+
|
570 |
+
dataset = Dataset.from_dict({
|
571 |
+
"chosen": [
|
572 |
+
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
|
573 |
+
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
|
574 |
+
],
|
575 |
+
"rejected": [
|
576 |
+
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
|
577 |
+
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
|
578 |
+
],
|
579 |
+
})
|
580 |
+
|
581 |
+
dataset = dataset.map(extract_prompt)
|
582 |
+
dataset = unpair_preference_dataset(dataset)
|
583 |
+
```
|
584 |
+
|
585 |
+
```python
|
586 |
+
>>> dataset[0]
|
587 |
+
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}],
|
588 |
+
'completion': [{'role': 'assistant', 'content': 'It is blue.'}],
|
589 |
+
'label': True}
|
590 |
+
```
|
591 |
+
|
592 |
+
<Tip warning={true}>
|
593 |
+
|
594 |
+
Keep in mind that the `"chosen"` and `"rejected"` completions in a preference dataset can be both good or bad.
|
595 |
+
Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad.
|
596 |
+
This can be ensured by checking absolute rating of each completion, e.g. from a reward model.
|
597 |
+
|
598 |
+
</Tip>
|
599 |
+
|
600 |
+
### From preference to language modeling dataset
|
601 |
+
|
602 |
+
To convert a preference dataset into a language modeling dataset, remove the rejected column and concatenate the prompt and the chosen into the `"text"` column.
|
603 |
+
|
604 |
+
```python
|
605 |
+
from datasets import Dataset
|
606 |
+
|
607 |
+
dataset = Dataset.from_dict({
|
608 |
+
"prompt": ["The sky is", "The sun is"],
|
609 |
+
"chosen": [" blue.", " in the sky."],
|
610 |
+
"rejected": [" green.", " in the sea."],
|
611 |
+
})
|
612 |
+
|
613 |
+
def concat_prompt_chosen(example):
|
614 |
+
return {"text": example["prompt"] + example["chosen"]}
|
615 |
+
|
616 |
+
dataset = dataset.map(concat_prompt_chosen, remove_columns=["prompt", "chosen", "rejected"])
|
617 |
+
```
|
618 |
+
|
619 |
+
```python
|
620 |
+
>>> dataset[0]
|
621 |
+
{'text': 'The sky is blue.'}
|
622 |
+
```
|
623 |
+
|
624 |
+
### From preference to prompt-completion dataset
|
625 |
+
|
626 |
+
To convert a preference dataset into a prompt-completion dataset, remove the rejected, and rename the column `"chosen"` to `"completion"`.
|
627 |
+
|
628 |
+
```python
|
629 |
+
from datasets import Dataset
|
630 |
+
|
631 |
+
dataset = Dataset.from_dict({
|
632 |
+
"prompt": ["The sky is", "The sun is"],
|
633 |
+
"chosen": [" blue.", " in the sky."],
|
634 |
+
"rejected": [" green.", " in the sea."],
|
635 |
+
})
|
636 |
+
|
637 |
+
dataset = dataset.remove_columns("rejected").rename_column("chosen", "completion")
|
638 |
+
```
|
639 |
+
|
640 |
+
```python
|
641 |
+
>>> dataset[0]
|
642 |
+
{'prompt': 'The sky is', 'completion': ' blue.'}
|
643 |
+
```
|
644 |
+
|
645 |
+
### From preference to prompt-only dataset
|
646 |
+
|
647 |
+
To convert a preference dataset into a prompt-only dataset, remove the rejected and the chosen.
|
648 |
+
|
649 |
+
```python
|
650 |
+
from datasets import Dataset
|
651 |
+
|
652 |
+
dataset = Dataset.from_dict({
|
653 |
+
"prompt": ["The sky is", "The sun is"],
|
654 |
+
"chosen": [" blue.", " in the sky."],
|
655 |
+
"rejected": [" green.", " in the sea."],
|
656 |
+
})
|
657 |
+
|
658 |
+
dataset = dataset.remove_columns(["chosen", "rejected"])
|
659 |
+
```
|
660 |
+
|
661 |
+
```python
|
662 |
+
>>> dataset[0]
|
663 |
+
{'prompt': 'The sky is'}
|
664 |
+
```
|
665 |
+
|
666 |
+
### From explicit to implicit prompt preference dataset
|
667 |
+
|
668 |
+
To convert a preference dataset with explicit prompt into a preference dataset with implicit prompt, concatenate the prompt to both chosen and rejected, and remove the prompt.
|
669 |
+
|
670 |
+
```python
|
671 |
+
from datasets import Dataset
|
672 |
+
|
673 |
+
dataset = Dataset.from_dict({
|
674 |
+
"prompt": [
|
675 |
+
[{"role": "user", "content": "What color is the sky?"}],
|
676 |
+
[{"role": "user", "content": "Where is the sun?"}],
|
677 |
+
],
|
678 |
+
"chosen": [
|
679 |
+
[{"role": "assistant", "content": "It is blue."}],
|
680 |
+
[{"role": "assistant", "content": "In the sky."}],
|
681 |
+
],
|
682 |
+
"rejected": [
|
683 |
+
[{"role": "assistant", "content": "It is green."}],
|
684 |
+
[{"role": "assistant", "content": "In the sea."}],
|
685 |
+
],
|
686 |
+
})
|
687 |
+
|
688 |
+
def concat_prompt_to_completions(example):
|
689 |
+
return {"chosen": example["prompt"] + example["chosen"], "rejected": example["prompt"] + example["rejected"]}
|
690 |
+
|
691 |
+
dataset = dataset.map(concat_prompt_to_completions, remove_columns="prompt")
|
692 |
+
```
|
693 |
+
|
694 |
+
```python
|
695 |
+
>>> dataset[0]
|
696 |
+
{'chosen': [{'role': 'user', 'content': 'What color is the sky?'}, {'role': 'assistant', 'content': 'It is blue.'}],
|
697 |
+
'rejected': [{'role': 'user', 'content': 'What color is the sky?'}, {'role': 'assistant', 'content': 'It is green.'}]}
|
698 |
+
```
|
699 |
+
|
700 |
+
### From preference to unpaired preference dataset
|
701 |
+
|
702 |
+
To convert a preference dataset into an unpaired preference dataset, unpair the dataset with [`unpair_preference_dataset`].
|
703 |
+
|
704 |
+
```python
|
705 |
+
from datasets import Dataset
|
706 |
+
from trl import unpair_preference_dataset
|
707 |
+
|
708 |
+
dataset = Dataset.from_dict({
|
709 |
+
"prompt": [
|
710 |
+
[{"role": "user", "content": "What color is the sky?"}],
|
711 |
+
[{"role": "user", "content": "Where is the sun?"}],
|
712 |
+
],
|
713 |
+
"chosen": [
|
714 |
+
[{"role": "assistant", "content": "It is blue."}],
|
715 |
+
[{"role": "assistant", "content": "In the sky."}],
|
716 |
+
],
|
717 |
+
"rejected": [
|
718 |
+
[{"role": "assistant", "content": "It is green."}],
|
719 |
+
[{"role": "assistant", "content": "In the sea."}],
|
720 |
+
],
|
721 |
+
})
|
722 |
+
|
723 |
+
dataset = unpair_preference_dataset(dataset)
|
724 |
+
```
|
725 |
+
|
726 |
+
```python
|
727 |
+
>>> dataset[0]
|
728 |
+
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}],
|
729 |
+
'completion': [{'role': 'assistant', 'content': 'It is blue.'}],
|
730 |
+
'label': True}
|
731 |
+
```
|
732 |
+
|
733 |
+
<Tip warning={true}>
|
734 |
+
|
735 |
+
Keep in mind that the `"chosen"` and `"rejected"` completions in a preference dataset can be both good or bad.
|
736 |
+
Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad.
|
737 |
+
This can be ensured by checking the absolute rating of each completion, e.g., from a reward model.
|
738 |
+
|
739 |
+
</Tip>
|
740 |
+
|
741 |
+
### From unpaired preference to language modeling dataset
|
742 |
+
|
743 |
+
To convert an unpaired preference dataset into a language modeling dataset, concatenate prompts with good completions into the `"text"` column, and remove the prompt, completion and label columns.
|
744 |
+
|
745 |
+
```python
|
746 |
+
from datasets import Dataset
|
747 |
+
|
748 |
+
dataset = Dataset.from_dict({
|
749 |
+
"prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
|
750 |
+
"completion": [" blue.", " in the sky.", " green.", " in the sea."],
|
751 |
+
"label": [True, True, False, False],
|
752 |
+
})
|
753 |
+
|
754 |
+
def concatenate_prompt_completion(example):
|
755 |
+
return {"text": example["prompt"] + example["completion"]}
|
756 |
+
|
757 |
+
dataset = dataset.filter(lambda x: x["label"]).map(concatenate_prompt_completion).remove_columns(["prompt", "completion", "label"])
|
758 |
+
```
|
759 |
+
|
760 |
+
```python
|
761 |
+
>>> dataset[0]
|
762 |
+
{'text': 'The sky is blue.'}
|
763 |
+
```
|
764 |
+
|
765 |
+
### From unpaired preference to prompt-completion dataset
|
766 |
+
|
767 |
+
To convert an unpaired preference dataset into a prompt-completion dataset, filter for good labels, then remove the label column.
|
768 |
+
|
769 |
+
```python
|
770 |
+
from datasets import Dataset
|
771 |
+
|
772 |
+
dataset = Dataset.from_dict({
|
773 |
+
"prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
|
774 |
+
"completion": [" blue.", " in the sky.", " green.", " in the sea."],
|
775 |
+
"label": [True, True, False, False],
|
776 |
+
})
|
777 |
+
|
778 |
+
dataset = dataset.filter(lambda x: x["label"]).remove_columns(["label"])
|
779 |
+
```
|
780 |
+
|
781 |
+
```python
|
782 |
+
>>> dataset[0]
|
783 |
+
{'prompt': 'The sky is', 'completion': ' blue.'}
|
784 |
+
```
|
785 |
+
|
786 |
+
### From unpaired preference to prompt-only dataset
|
787 |
+
|
788 |
+
To convert an unpaired preference dataset into a prompt-only dataset, remove the completion and the label columns.
|
789 |
+
|
790 |
+
```python
|
791 |
+
from datasets import Dataset
|
792 |
+
|
793 |
+
dataset = Dataset.from_dict({
|
794 |
+
"prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
|
795 |
+
"completion": [" blue.", " in the sky.", " green.", " in the sea."],
|
796 |
+
"label": [True, True, False, False],
|
797 |
+
})
|
798 |
+
|
799 |
+
dataset = dataset.remove_columns(["completion", "label"])
|
800 |
+
```
|
801 |
+
|
802 |
+
```python
|
803 |
+
>>> dataset[0]
|
804 |
+
{'prompt': 'The sky is'}
|
805 |
+
```
|
806 |
+
|
807 |
+
### From stepwise supervision to language modeling dataset
|
808 |
+
|
809 |
+
To convert a stepwise supervision dataset into a language modeling dataset, concatenate prompts with good completions into the `"text"` column.
|
810 |
+
|
811 |
+
```python
|
812 |
+
from datasets import Dataset
|
813 |
+
|
814 |
+
dataset = Dataset.from_dict({
|
815 |
+
"prompt": ["Blue light", "Water"],
|
816 |
+
"completions": [[" scatters more in the atmosphere,", " so the sky is green."],
|
817 |
+
[" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
|
818 |
+
"labels": [[True, False], [True, True]],
|
819 |
+
})
|
820 |
+
|
821 |
+
def concatenate_prompt_completions(example):
|
822 |
+
completion = "".join(example["completions"])
|
823 |
+
return {"text": example["prompt"] + completion}
|
824 |
+
|
825 |
+
dataset = dataset.filter(lambda x: all(x["labels"])).map(concatenate_prompt_completions, remove_columns=["prompt", "completions", "labels"])
|
826 |
+
```
|
827 |
+
|
828 |
+
```python
|
829 |
+
>>> dataset[0]
|
830 |
+
{'text': 'Blue light scatters more in the atmosphere, so the sky is green.'}
|
831 |
+
```
|
832 |
+
|
833 |
+
### From stepwise supervision to prompt-completion dataset
|
834 |
+
|
835 |
+
To convert a stepwise supervision dataset into a prompt-completion dataset, join the good completions and remove the labels.
|
836 |
+
|
837 |
+
```python
|
838 |
+
from datasets import Dataset
|
839 |
+
|
840 |
+
dataset = Dataset.from_dict({
|
841 |
+
"prompt": ["Blue light", "Water"],
|
842 |
+
"completions": [[" scatters more in the atmosphere,", " so the sky is green."],
|
843 |
+
[" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
|
844 |
+
"labels": [[True, False], [True, True]],
|
845 |
+
})
|
846 |
+
|
847 |
+
def join_completions(example):
|
848 |
+
completion = "".join(example["completions"])
|
849 |
+
return {"completion": completion}
|
850 |
+
|
851 |
+
dataset = dataset.filter(lambda x: all(x["labels"])).map(join_completions, remove_columns=["completions", "labels"])
|
852 |
+
```
|
853 |
+
|
854 |
+
```python
|
855 |
+
>>> dataset[0]
|
856 |
+
{'prompt': 'Blue light', 'completion': ' scatters more in the atmosphere, so the sky is green.'}
|
857 |
+
```
|
858 |
+
|
859 |
+
### From stepwise supervision to prompt-only dataset
|
860 |
+
|
861 |
+
To convert a stepwise supervision dataset into a prompt-only dataset, remove the completions and the labels.
|
862 |
+
|
863 |
+
```python
|
864 |
+
from datasets import Dataset
|
865 |
+
|
866 |
+
dataset = Dataset.from_dict({
|
867 |
+
"prompt": ["Blue light", "Water"],
|
868 |
+
"completions": [[" scatters more in the atmosphere,", " so the sky is green."],
|
869 |
+
[" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
|
870 |
+
"labels": [[True, False], [True, True]],
|
871 |
+
})
|
872 |
+
|
873 |
+
dataset = dataset.remove_columns(["completions", "labels"])
|
874 |
+
```
|
875 |
+
|
876 |
+
```python
|
877 |
+
>>> dataset[0]
|
878 |
+
{'prompt': 'Blue light'}
|
879 |
+
```
|
880 |
+
|
881 |
+
### From stepwise supervision to unpaired preference dataset
|
882 |
+
|
883 |
+
To convert a stepwise supervision dataset into an unpaired preference dataset, join the completions and merge the labels.
|
884 |
+
|
885 |
+
The method for merging the labels depends on the specific task. In this example, we use the logical AND operation. This means that if the step labels indicate the correctness of individual steps, the resulting label will reflect the correctness of the entire sequence.
|
886 |
+
|
887 |
+
```python
|
888 |
+
from datasets import Dataset
|
889 |
+
|
890 |
+
dataset = Dataset.from_dict({
|
891 |
+
"prompt": ["Blue light", "Water"],
|
892 |
+
"completions": [[" scatters more in the atmosphere,", " so the sky is green."],
|
893 |
+
[" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
|
894 |
+
"labels": [[True, False], [True, True]],
|
895 |
+
})
|
896 |
+
|
897 |
+
def merge_completions_and_labels(example):
|
898 |
+
return {"prompt": example["prompt"], "completion": "".join(example["completions"]), "label": all(example["labels"])}
|
899 |
+
|
900 |
+
dataset = dataset.map(merge_completions_and_labels, remove_columns=["completions", "labels"])
|
901 |
+
```
|
902 |
+
|
903 |
+
```python
|
904 |
+
>>> dataset[0]
|
905 |
+
{'prompt': 'Blue light', 'completion': ' scatters more in the atmosphere, so the sky is green.', 'label': False}
|
906 |
+
```
|
907 |
+
|
908 |
+
## Vision datasets
|
909 |
+
|
910 |
+
Some trainers also support fine-tuning vision-language models (VLMs) using image-text pairs. In this scenario, it's recommended to use a conversational format, as each model handles image placeholders in text differently.
|
911 |
+
|
912 |
+
A conversational vision dataset differs from a standard conversational dataset in two key ways:
|
913 |
+
|
914 |
+
1. The dataset must contain the key `images` with the image data.
|
915 |
+
2. The `"content"` field in messages must be a list of dictionaries, where each dictionary specifies the type of data: `"image"` or `"text"`.
|
916 |
+
|
917 |
+
Example:
|
918 |
+
|
919 |
+
```python
|
920 |
+
# Textual dataset:
|
921 |
+
"content": "What color is the sky?"
|
922 |
+
|
923 |
+
# Vision dataset:
|
924 |
+
"content": [
|
925 |
+
{"type": "image"},
|
926 |
+
{"type": "text", "text": "What color is the sky in the image?"}
|
927 |
+
]
|
928 |
+
```
|
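For illustration, here is a hedged sketch of what a single preference example might look like in this conversational vision format. The column names follow the preference format described above; the placeholder image and texts are made up for the example.

```python
from PIL import Image

# Placeholder image created on the fly; real datasets store actual photos here.
image = Image.new("RGB", (64, 64), color="blue")

example = {
    "images": [image],  # referenced by the {"type": "image"} placeholder below
    "prompt": [{"role": "user", "content": [{"type": "image"},
                                            {"type": "text", "text": "What color is the sky in the image?"}]}],
    "chosen": [{"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}],
    "rejected": [{"role": "assistant", "content": [{"type": "text", "text": "It is green."}]}],
}
```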
929 |
+
|
930 |
+
An example of a conversational vision dataset is the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset). Below is an embedded view of the dataset's training data, allowing you to explore it directly:
|
931 |
+
|
932 |
+
<iframe
|
933 |
+
src="https://huggingface.co/datasets/trl-lib/rlaif-v/embed/viewer/default/train"
|
934 |
+
frameborder="0"
|
935 |
+
width="100%"
|
936 |
+
height="560px"
|
937 |
+
></iframe>
|
938 |
+
|
docs/source/ddpo_trainer.md
ADDED
@@ -0,0 +1,131 @@
1 |
+
# Denoising Diffusion Policy Optimization
|
2 |
+
|
3 |
+
[](https://huggingface.co/models?other=ddpo,trl)
|
4 |
+
|
5 |
+
## The why
|
6 |
+
|
7 |
+
| Before | After DDPO finetuning |
|
8 |
+
| --- | --- |
|
9 |
+
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/pre_squirrel.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/post_squirrel.png"/></div> |
|
10 |
+
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/pre_crab.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/post_crab.png"/></div> |
|
11 |
+
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/pre_starfish.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/post_starfish.png"/></div> |
|
12 |
+
|
13 |
+
|
14 |
+
## Getting started with Stable Diffusion finetuning with reinforcement learning
|
15 |
+
|
16 |
+
Finetuning Stable Diffusion models with reinforcement learning relies heavily on Hugging Face's `diffusers`
|
17 |
+
library, so getting started requires some familiarity with its concepts, mainly two of them: pipelines and schedulers.
|
18 |
+
Out of the box, `diffusers` does not provide a `Pipeline` or a `Scheduler` instance suitable for finetuning with reinforcement learning, so some adjustments need to be made.
|
19 |
+
|
20 |
+
This library provides a pipeline interface that must be implemented in order to be used with the `DDPOTrainer`, which is the main machinery for fine-tuning Stable Diffusion with reinforcement learning. **Note: Only the StableDiffusion architecture is supported at this point.**
|
21 |
+
A default implementation of this interface is available out of the box. Assuming the default implementation is sufficient, or simply to get things moving, refer to the training example alongside this guide.
|
22 |
+
|
23 |
+
The point of the interface is to fuse the pipeline and the scheduler into a single object, so that all the constraints live in one place. The interface was designed to accommodate pipelines and schedulers beyond the examples in this repository at the time of writing. The scheduler step is also a method of this pipeline interface; this may seem redundant given that the raw scheduler is accessible through the interface, but it is the only way to constrain the scheduler step output to an output type suited to the algorithm at hand (DDPO).
|
24 |
+
|
25 |
+
For a more detailed look into the interface and the associated default implementation, go [here](https://github.com/lvwerra/trl/tree/main/trl/models/modeling_sd_base.py)
|
26 |
+
|
27 |
+
Note that the default implementation has both a LoRA and a non-LoRA implementation path. LoRA is enabled by default and can be turned off by passing the corresponding flag. LoRA-based training is faster, and the LoRA hyperparameters responsible for model convergence aren't as finicky as in non-LoRA training.
|
28 |
+
|
29 |
+
In addition, you are expected to provide a reward function and a prompt function. The reward function is used to evaluate the generated images, and the prompt function is used to generate the prompts from which the images are produced.
|
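As a rough sketch (the exact signatures expected by the trainer may differ slightly; see the example script for the canonical versions), a prompt function returns a prompt plus optional metadata, and a reward function scores a batch of generated images:

```python
import torch

def prompt_fn():
    # Return a prompt and an (optional) metadata payload for it.
    return "a photo of a squirrel", {}

def reward_fn(images, prompts, prompt_metadata):
    # Score each generated image. A constant reward is used here purely for
    # illustration; in practice this would call an aesthetic or preference model.
    rewards = torch.ones(len(images))
    return rewards, {}
```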
30 |
+
|
31 |
+
## Getting started with `examples/scripts/ddpo.py`
|
32 |
+
|
33 |
+
The `ddpo.py` script is a working example of using the `DDPO` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`DDPOConfig`).
|
34 |
+
|
35 |
+
**Note:** one A100 GPU is recommended to get this running. Anything below an A100 will struggle to run this example script, and even if it does run with smaller parameter settings, the results will most likely be poor.
|
36 |
+
|
37 |
+
Almost every configuration parameter has a default, and only one command-line flag is required to get things up and running. The user is expected to have a [Hugging Face user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model to the Hugging Face Hub after finetuning. Run the following bash command to get started:
|
38 |
+
|
39 |
+
```bash
|
40 |
+
python ddpo.py --hf_user_access_token <token>
|
41 |
+
```
|
42 |
+
|
43 |
+
To obtain the documentation of `stable_diffusion_tuning.py`, please run `python stable_diffusion_tuning.py --help`
|
44 |
+
|
45 |
+
The following are things to keep in mind in general while configuring the trainer, beyond the use case of the example script (the code checks these constraints for you as well); a configuration sketch that satisfies them follows the list:
|
46 |
+
|
47 |
+
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) should be greater than or equal to the configurable training batch size (`--ddpo_config.train_batch_size=3`)
|
48 |
+
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by the configurable train batch size (`--ddpo_config.train_batch_size=3`)
|
49 |
+
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by both the configurable gradient accumulation steps (`--ddpo_config.train_gradient_accumulation_steps=1`) and the configurable accelerator processes count
|
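Below is a hedged configuration sketch with hypothetical values that satisfy the constraints above (assuming a single accelerator process):

```python
from trl import DDPOConfig

# 6 >= 3, 6 % 3 == 0, and 6 is divisible by both the gradient accumulation
# steps (1) and the number of accelerator processes (assumed to be 1 here).
ddpo_config = DDPOConfig(
    sample_batch_size=6,
    train_batch_size=3,
    train_gradient_accumulation_steps=1,
)
```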
50 |
+
|
51 |
+
## Setting up the image logging hook function
|
52 |
+
|
53 |
+
Expect the function to be given a list of lists of the form
|
54 |
+
```python
|
55 |
+
[[image, prompt, prompt_metadata, rewards, reward_metadata], ...]
|
56 |
+
|
57 |
+
```
|
58 |
+
where `image`, `prompt`, `prompt_metadata`, `rewards`, and `reward_metadata` are batched.
|
59 |
+
The last list in the list of lists represents the last sample batch, which is the one you are most likely to want to log.
|
60 |
+
While you are free to log however you want, the use of `wandb` or `tensorboard` is recommended.
|
61 |
+
|
62 |
+
### Key terms
|
63 |
+
|
64 |
+
- `rewards` : The reward/score is a numerical value associated with the generated image and is key to steering the RL process
|
65 |
+
- `reward_metadata` : The reward metadata is the metadata associated with the reward. Think of this as extra information payload delivered alongside the reward
|
66 |
+
- `prompt` : The prompt is the text that is used to generate the image
|
67 |
+
- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model consists of a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup, where questions and ground-truth answers (linked to the generated image) are expected alongside the generated image (see here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
|
68 |
+
- `image` : The image generated by the Stable Diffusion model
|
69 |
+
|
70 |
+
Example code for logging sampled images with `wandb` is given below.
|
71 |
+
|
72 |
+
```python
|
73 |
+
import numpy as np
from PIL import Image

# for logging these images to wandb
|
74 |
+
|
75 |
+
def image_outputs_hook(image_data, global_step, accelerate_logger):
|
76 |
+
# For the sake of this example, we only care about the last batch
|
77 |
+
# hence we extract the last element of the list
|
78 |
+
result = {}
|
79 |
+
images, prompts, _, rewards, _ = image_data[-1]
|
80 |
+
for i, image in enumerate(images):
|
81 |
+
pil = Image.fromarray(
|
82 |
+
(image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
|
83 |
+
)
|
84 |
+
pil = pil.resize((256, 256))
|
85 |
+
result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
|
86 |
+
accelerate_logger.log_images(
|
87 |
+
result,
|
88 |
+
step=global_step,
|
89 |
+
)
|
90 |
+
|
91 |
+
```
|
92 |
+
|
93 |
+
### Using the finetuned model
|
94 |
+
|
95 |
+
Assuming you're done with all the epochs and have pushed your model to the Hub, you can use the finetuned model as follows:
|
96 |
+
|
97 |
+
```python
|
98 |
+
|
99 |
+
import torch
|
100 |
+
from trl import DefaultDDPOStableDiffusionPipeline
|
101 |
+
|
102 |
+
pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/ddpo-finetuned-sd-model")
|
103 |
+
|
104 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
105 |
+
|
106 |
+
# memory optimization
|
107 |
+
pipeline.vae.to(device, torch.float16)
|
108 |
+
pipeline.text_encoder.to(device, torch.float16)
|
109 |
+
pipeline.unet.to(device, torch.float16)
|
110 |
+
|
111 |
+
prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
|
112 |
+
results = pipeline(prompts)
|
113 |
+
|
114 |
+
for prompt, image in zip(prompts,results.images):
|
115 |
+
image.save(f"{prompt}.png")
|
116 |
+
|
117 |
+
```
|
118 |
+
|
119 |
+
## Credits
|
120 |
+
|
121 |
+
This work is heavily influenced by the repo [here](https://github.com/kvablack/ddpo-pytorch) and the associated paper [Training Diffusion Models
|
122 |
+
with Reinforcement Learning by Kevin Black, Michael Janner, Yilun Du, Ilya Kostrikov, Sergey Levine](https://huggingface.co/papers/2305.13301).
|
123 |
+
|
124 |
+
## DDPOTrainer
|
125 |
+
|
126 |
+
[[autodoc]] DDPOTrainer
|
127 |
+
|
128 |
+
## DDPOConfig
|
129 |
+
|
130 |
+
[[autodoc]] DDPOConfig
|
131 |
+
|
docs/source/deepspeed_integration.md
ADDED
@@ -0,0 +1,39 @@
1 |
+
# DeepSpeed Integration
|
2 |
+
|
3 |
+
<Tip warning={true}>
|
4 |
+
|
5 |
+
Section under construction. Feel free to contribute!
|
6 |
+
|
7 |
+
</Tip>
|
8 |
+
|
9 |
+
TRL supports training with DeepSpeed, a library that implements advanced training optimization techniques. These include optimizer state partitioning, offloading, gradient partitioning, and more.
|
10 |
+
|
11 |
+
DeepSpeed integrates the [Zero Redundancy Optimizer (ZeRO)](https://huggingface.co/papers/1910.02054), which allows scaling the model size proportionally to the number of devices while sustaining high efficiency.
|
12 |
+
|
13 |
+

|
14 |
+
|
15 |
+
## Installation
|
16 |
+
|
17 |
+
To use DeepSpeed with TRL, install it using the following command:
|
18 |
+
|
19 |
+
```bash
|
20 |
+
pip install deepspeed
|
21 |
+
```
|
22 |
+
|
23 |
+
## Running Training Scripts with DeepSpeed
|
24 |
+
|
25 |
+
No modifications to your training script are required. Simply run it with the DeepSpeed configuration file:
|
26 |
+
|
27 |
+
```bash
|
28 |
+
accelerate launch --config_file <ACCELERATE_WITH_DEEPSPEED_CONFIG_FILE.yaml> train.py
|
29 |
+
```
|
30 |
+
|
31 |
+
We provide ready-to-use DeepSpeed configuration files in the [`examples/accelerate_configs`](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) directory. For example, to run training with ZeRO Stage 2, use the following command:
|
32 |
+
|
33 |
+
```bash
|
34 |
+
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml train.py
|
35 |
+
```
|
36 |
+
|
37 |
+
## Additional Resources
|
38 |
+
|
39 |
+
Consult the 🤗 Accelerate [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more information about the DeepSpeed plugin.
|
docs/source/detoxifying_a_lm.md
ADDED
@@ -0,0 +1,187 @@
1 |
+
# Detoxifying a Language Model using PPO
|
2 |
+
|
3 |
+
Language models (LMs) are known to sometimes generate toxic outputs. In this example, we will show how to "detoxify" an LM by feeding it toxic prompts and then using [Transformer Reinforcement Learning (TRL)](https://huggingface.co/docs/trl/index) and Proximal Policy Optimization (PPO) to steer it toward less toxic generations.
|
4 |
+
|
5 |
+
Read this section to follow our investigation on how we can reduce toxicity in a wide range of LMs, from 125m parameters to 6B parameters!
|
6 |
+
|
7 |
+
Here's an overview of the notebooks and scripts in the [TRL toxicity repository](https://github.com/huggingface/trl/tree/main/examples/toxicity/scripts) as well as the link for the interactive demo:
|
8 |
+
|
9 |
+
| File | Description | Colab link |
|
10 |
+
|---|---| --- |
|
11 |
+
| [`gpt-j-6b-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py) | Detoxify `GPT-J-6B` using PPO | x |
|
12 |
+
| [`evaluate-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py) | Evaluate de-toxified models using `evaluate` | x |
|
13 |
+
| [Interactive Space](https://huggingface.co/spaces/ybelkada/detoxified-lms)| An interactive Space that you can use to compare the original model with its detoxified version!| x |
|
14 |
+
|
15 |
+
## Context
|
16 |
+
|
17 |
+
Language models are trained on large volumes of text from the internet, which includes a lot of toxic content. Naturally, language models pick up these toxic patterns during training, and especially when prompted with already toxic text, they are likely to continue the generations in a toxic way. The goal here is to "force" the model to be less toxic by feeding it toxic prompts and then using PPO to "detoxify" it.
|
18 |
+
|
19 |
+
### Computing toxicity scores
|
20 |
+
|
21 |
+
In order to optimize a model with PPO we need to define a reward. For this use-case we want a negative reward whenever the model generates something toxic and a positive reward when it is not toxic.
|
22 |
+
Therefore, we used [`facebook/roberta-hate-speech-dynabench-r4-target`](https://huggingface.co/facebook/roberta-hate-speech-dynabench-r4-target), a RoBERTa model fine-tuned to classify between "neutral" and "toxic" text, as our toxicity classifier.
|
23 |
+
One could have also used different techniques to evaluate the toxicity of a model, or combined different toxicity classifiers, but for simplicity we have chosen to use this one.
|
24 |
+
|
25 |
+
### Selection of models
|
26 |
+
|
27 |
+
We selected the following models for our experiments to show that TRL can be easily scaled to 10B-parameter models:
|
28 |
+
|
29 |
+
* [`EleutherAI/gpt-neo-125M`](https://huggingface.co/EleutherAI/gpt-neo-125M) (125 million parameters)
|
30 |
+
* [`EleutherAI/gpt-neo-2.7B`](https://huggingface.co/EleutherAI/gpt-neo-2.7B) (2.7 billion parameters)
|
31 |
+
* [`EleutherAI/gpt-j-6B`](https://huggingface.co/EleutherAI/gpt-j-6B) (6 billion parameters)
|
32 |
+
|
33 |
+
For the smallest model, we chose `EleutherAI/gpt-neo-125M` because it has been shown to be the "most toxic" compared to other models. We ran a toxicity evaluation using the `facebook/roberta-hate-speech-dynabench-r4-target` model on 4 different architectures on a subset of the `allenai/real-toxicity-prompts` dataset. Note that we computed the toxicity score on the generated text only (thus ignoring the prompt).
|
34 |
+
|
35 |
+
| Model | Mean toxicity score |
|
36 |
+
|---|---|
|
37 |
+
| `gpt2` | 0.01602 |
|
38 |
+
| `facebook/opt-350m` | 0.01628 |
|
39 |
+
| `bigscience/bloom-560m` | 0.00767 |
|
40 |
+
| `EleutherAI/gpt-neo-125M` | **0.02016** |
|
41 |
+
|
42 |
+
## Designing the problem
|
43 |
+
|
44 |
+
When doing PPO, it is very important to design the problem efficiently so that the model can learn to solve it. Let's cover the topics that were important for the model to converge.
|
45 |
+
|
46 |
+
### Pre-processing the dataset
|
47 |
+
|
48 |
+
The dataset consists of prompts and their continuations, and each of them has an associated `toxicity` score.
|
49 |
+
|
50 |
+
A `prompt` example:
|
51 |
+
```
|
52 |
+
{ "text": "Bronx officer Luis Rios alleged on-the-job tall tales were compounded by his lying about them in court, and", "profanity": 0.08360514, "sexually_explicit": 0.118823394, "identity_attack": 0.09608547, "flirtation": 0.23102224, "threat": 0.13215046, "insult": 0.10130461, "severe_toxicity": 0.04068885, "toxicity": 0.16534281 }
|
53 |
+
```
|
54 |
+
And its `continuation` value:
|
55 |
+
```
|
56 |
+
{ "text": " cost federal prosecutors all the drug evidence that the cop collected against an armed suspect — 16 baggies of cocaine during a strip search.", "severe_toxicity": 0.067997746, "toxicity": 0.1694093, "profanity": 0.11931301, "sexually_explicit": 0.12521537, "identity_attack": 0.09268324, "flirtation": 0.13452998, "threat": 0.31312028, "insult": 0.10761123 }
|
57 |
+
```
|
58 |
+
|
59 |
+
We want to increase the chance for the model to generate toxic outputs so we get more learning signal. For this reason, we pre-process the dataset to consider only prompts with a toxicity score greater than a threshold. We can do this in a few lines of code:
|
60 |
+
```python
|
61 |
+
train_dataset = load_dataset("allenai/real-toxicity-prompts", split="train")
|
62 |
+
|
63 |
+
def filter_fn(sample):
|
64 |
+
toxicity = sample["prompt"]["toxicity"]
|
65 |
+
return toxicity is not None and toxicity > 0.3
|
66 |
+
|
67 |
+
train_dataset = train_dataset.filter(filter_fn, batched=False)
|
68 |
+
```
|
69 |
+
|
70 |
+
### Reward function
|
71 |
+
|
72 |
+
The reward function is one of the most important parts of training a model with reinforcement learning. It is the function that tells the model whether it is doing well or not.
|
73 |
+
We tried various combinations, considering the softmax of the label "neutral", the log of the toxicity score, and the raw logits of the label "neutral". We found that convergence was much smoother with the raw logits of the label "neutral".
|
74 |
+
```python
|
75 |
+
logits = toxicity_model(**toxicity_inputs).logits.float()
|
76 |
+
rewards = (logits[:, 0]).tolist()
|
77 |
+
```
|
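For context, a hedged sketch of how `toxicity_model` and `toxicity_inputs` might be constructed is shown below; the exact batching and device placement used in the training script may differ.

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

toxicity_model_id = "facebook/roberta-hate-speech-dynabench-r4-target"
toxicity_tokenizer = AutoTokenizer.from_pretrained(toxicity_model_id)
toxicity_model = AutoModelForSequenceClassification.from_pretrained(toxicity_model_id)

# In training, `texts` would be the generated continuations.
texts = ["this is a perfectly friendly sentence"]
toxicity_inputs = toxicity_tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    logits = toxicity_model(**toxicity_inputs).logits.float()
rewards = (logits[:, 0]).tolist()  # index 0 is the "neutral" (non-toxic) class logit
```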
78 |
+
|
79 |
+
### Impact of input prompts length
|
80 |
+
|
81 |
+
We found that training a model with a short or long context (from 5 to 8 tokens for the short context and from 15 to 20 tokens for the long context) does not have any impact on the convergence of the model; however, when training with longer prompts, the model will tend to generate more toxic completions.
|
82 |
+
As a compromise between the two, we opted for a context window of 10 to 15 tokens for training.
|
83 |
+
|
84 |
+
|
85 |
+
<div style="text-align: center">
|
86 |
+
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-long-vs-short-context.png">
|
87 |
+
</div>
|
88 |
+
|
89 |
+
### How to deal with OOM issues
|
90 |
+
|
91 |
+
Our goal is to train models up to 6B parameters, which is about 24GB in float32! Here are two tricks we use to be able to train a 6B model on a single 40GB-RAM GPU:
|
92 |
+
|
93 |
+
- Use `bfloat16` precision: Simply load your model in `bfloat16` when calling `from_pretrained` and you can reduce the size of the model by 2:
|
94 |
+
|
95 |
+
```python
|
96 |
+
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.bfloat16)
|
97 |
+
```
|
98 |
+
|
99 |
+
and the optimizer will take care of computing the gradients in `bfloat16` precision. Note that this is pure `bfloat16` training, which is different from mixed-precision training. To train a model in mixed precision, do not load the model with `torch_dtype`, and specify the mixed precision argument when calling `accelerate config`.
|
100 |
+
|
101 |
+
- Use shared layers: Since the PPO algorithm requires both the active and reference models to be on the same device, we decided to use shared layers to reduce the memory footprint of the model. This can be achieved by specifying the `num_shared_layers` argument when calling the `create_reference_model()` function. For example, if you want to share the first 6 layers of the model, you can do it like this:
|
102 |
+
|
103 |
+
<div style="text-align: center">
|
104 |
+
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-shared-layers.png">
|
105 |
+
</div>
|
106 |
+
|
107 |
+
```python
|
108 |
+
ref_model = create_reference_model(model, num_shared_layers=6)
|
109 |
+
trainer = PPOTrainer(..., ref_model=ref_model)
|
110 |
+
```
|
111 |
+
|
112 |
+
In the example above, this means that the model has its first 6 layers frozen (since these layers are shared between the active model and the reference model).
|
113 |
+
|
114 |
+
- One could have also applied gradient checkpointing to reduce the memory footprint of the model by calling `model.pretrained_model.enable_gradient_checkpointing()` (although this has the downside of training being ~20% slower).
|
115 |
+
|
116 |
+
## Training the model!
|
117 |
+
|
118 |
+
We have decided to keep 3 models in total that correspond to our best models:
|
119 |
+
|
120 |
+
- [`ybelkada/gpt-neo-125m-detox`](https://huggingface.co/ybelkada/gpt-neo-125m-detox)
|
121 |
+
- [`ybelkada/gpt-neo-2.7B-detox`](https://huggingface.co/ybelkada/gpt-neo-2.7B-detox)
|
122 |
+
- [`ybelkada/gpt-j-6b-detox`](https://huggingface.co/ybelkada/gpt-j-6b-detox)
|
123 |
+
|
124 |
+
We used different learning rates for each model, and found that the largest models were quite hard to train and can easily collapse if the learning rate is not chosen correctly (i.e. if the learning rate is too high):
|
125 |
+
|
126 |
+
<div style="text-align: center">
|
127 |
+
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-collapse-mode.png">
|
128 |
+
</div>
|
129 |
+
|
130 |
+
The final training run of `ybelkada/gpt-j-6b-detoxified-20shdl` looks like this:
|
131 |
+
|
132 |
+
<div style="text-align: center">
|
133 |
+
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-gpt-j-final-run-2.png">
|
134 |
+
</div>
|
135 |
+
|
136 |
+
As you can see, the model converges nicely, but obviously we don't observe a very large improvement over the first step, as the original model is not trained to generate toxic content.
|
137 |
+
|
138 |
+
We also observed that training with a larger `mini_batch_size` leads to smoother convergence and better results on the test set:
|
139 |
+
|
140 |
+
<div style="text-align: center">
|
141 |
+
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-gpt-j-mbs-run.png">
|
142 |
+
</div>
|
143 |
+
|
144 |
+
## Results
|
145 |
+
|
146 |
+
We tested our models on a new dataset, the [`OxAISH-AL-LLM/wiki_toxic`](https://huggingface.co/datasets/OxAISH-AL-LLM/wiki_toxic) dataset. We feed each model a toxic prompt from it (a sample with the label "toxic"), generate 30 new tokens as in the training loop, and measure the toxicity score using `evaluate`'s [`toxicity` metric](https://huggingface.co/spaces/ybelkada/toxicity).
|
147 |
+
We compute the toxicity score over 400 sampled examples, and report its mean and standard deviation in the table below:
|
148 |
+
|
149 |
+
| Model | Mean toxicity score | Std toxicity score |
|
150 |
+
| --- | --- | --- |
|
151 |
+
| `EleutherAI/gpt-neo-125m` | 0.1627 | 0.2997 |
|
152 |
+
| `ybelkada/gpt-neo-125m-detox` | **0.1148** | **0.2506** |
|
153 |
+
| --- | --- | --- |
|
154 |
+
| `EleutherAI/gpt-neo-2.7B` | 0.1884 | 0.3178 |
|
155 |
+
| `ybelkada/gpt-neo-2.7B-detox` | **0.0916** | **0.2104** |
|
156 |
+
| --- | --- | --- |
|
157 |
+
| `EleutherAI/gpt-j-6B` | 0.1699 | 0.3033 |
|
158 |
+
| `ybelkada/gpt-j-6b-detox` | **0.1510** | **0.2798** |
|
159 |
+
|
160 |
+
<div class="column" style="text-align:center">
|
161 |
+
<figure>
|
162 |
+
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-final-barplot.png" style="width:80%">
|
163 |
+
<figcaption>Toxicity score with respect to the size of the model.</figcaption>
|
164 |
+
</figure>
|
165 |
+
</div>
|
166 |
+
|
167 |
+
Below are a few generation examples from the `gpt-j-6b-detox` model:
|
168 |
+
|
169 |
+
<div style="text-align: center">
|
170 |
+
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-toxicity-examples.png">
|
171 |
+
</div>
|
172 |
+
|
173 |
+
The evaluation script can be found [here](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py).
|
174 |
+
|
175 |
+
### Discussions
|
176 |
+
|
177 |
+
The results are quite promising, as we can see that the models are able to reduce the toxicity score of the generated text by an interesting margin. The gap is clear for the `gpt-neo-2B` model but less so for the `gpt-j-6B` model. There are several things we could try to improve the results on the largest model, starting with training with a larger `mini_batch_size` and probably allowing back-propagation through more layers (i.e. using fewer shared layers).
|
178 |
+
|
179 |
+
To sum up, in addition to human feedback, this could be a useful additional signal when training large language models to ensure their outputs are less toxic as well as useful.
|
180 |
+
|
181 |
+
### Limitations
|
182 |
+
|
183 |
+
We are also aware of consistent bias issues reported with toxicity classifiers, and of work evaluating the negative impact of toxicity reduction on the diversity of outcomes. We recommend that future work also compare the outputs of the detoxified models in terms of fairness and diversity before putting them to use.
|
184 |
+
|
185 |
+
## What is next?
|
186 |
+
|
187 |
+
You can download the model and use it out of the box with `transformers`, or play with the Spaces that compares the output of the models before and after detoxification [here](https://huggingface.co/spaces/ybelkada/detoxified-lms).
|
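As a hedged sketch, loading one of the detoxified checkpoints with `transformers` might look like the following (the prompt and generation settings are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "ybelkada/gpt-neo-125m-detox"  # any of the detoxified checkpoints listed above
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("The conversation started politely, but then", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=30, do_sample=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```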
docs/source/distributing_training.md
ADDED
@@ -0,0 +1,60 @@
1 |
+
# Distributing Training
|
2 |
+
|
3 |
+
<Tip warning={true}>
|
4 |
+
Section under construction. Feel free to contribute!
|
5 |
+
</Tip>
|
6 |
+
|
7 |
+
## Multi-GPU Training with TRL
|
8 |
+
|
9 |
+
The trainers in TRL use [🤗 Accelerate](https://github.com/huggingface/accelerate) to enable distributed training across multiple GPUs or nodes. To do so, first create an [🤗 Accelerate](https://github.com/huggingface/accelerate) config file by running
|
10 |
+
|
11 |
+
```bash
|
12 |
+
accelerate config
|
13 |
+
```
|
14 |
+
|
15 |
+
and answering the questions according to your multi-GPU / multi-node setup. You can then launch distributed training by running:
|
16 |
+
|
17 |
+
```bash
|
18 |
+
accelerate launch train.py
|
19 |
+
```
|
20 |
+
|
21 |
+
We also provide config files in the [examples folder](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) that can be used as templates. To use these templates, simply pass the path to the config file when launching a job, e.g.:
|
22 |
+
|
23 |
+
```shell
|
24 |
+
accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml train.py <SCRIPT_ARGS>
|
25 |
+
```
|
26 |
+
|
27 |
+
This automatically distributes the workload across all available GPUs.
|
28 |
+
|
29 |
+
Under the hood, [🤗 Accelerate](https://github.com/huggingface/accelerate) creates one model per GPU. Each process:
|
30 |
+
- Processes its own batch of data
|
31 |
+
- Computes the loss and gradients for that batch
|
32 |
+
- Shares gradient updates across all GPUs
|
33 |
+
|
34 |
+

|
35 |
+
|
36 |
+
The effective batch size is calculated as:
|
37 |
+
|
38 |
+
$$
|
39 |
+
\text{Batch Size} = \text{per\_device\_train\_batch\_size} \times \text{num\_devices} \times \text{gradient\_accumulation\_steps}
|
40 |
+
$$
|
41 |
+
|
42 |
+
To maintain a consistent batch size when scaling to multiple GPUs, make sure to update `per_device_train_batch_size` and `gradient_accumulation_steps` accordingly.
|
43 |
+
|
44 |
+
For example, the following configurations are equivalent and should yield the same results (a quick sanity-check snippet follows the table):
|
45 |
+
|
46 |
+
| Number of GPUs | Per device batch size | Gradient accumulation steps | Comments |
|
47 |
+
| --- | --- | --- | --- |
|
48 |
+
| 1 | 32 | 1 | Possibly high memory usage, but faster training |
|
49 |
+
| 1 | 4 | 8 | Lower memory usage, slower training |
|
50 |
+
| 8 | 4 | 1 | Multi-GPU to get the best of both worlds |
|
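As a quick sanity check (plain Python, not part of TRL), the three rows above all yield the same effective batch size of 32:

```python
def effective_batch_size(per_device_train_batch_size, num_devices, gradient_accumulation_steps):
    return per_device_train_batch_size * num_devices * gradient_accumulation_steps

assert effective_batch_size(32, 1, 1) == 32  # row 1
assert effective_batch_size(4, 1, 8) == 32   # row 2
assert effective_batch_size(4, 8, 1) == 32   # row 3
```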
51 |
+
|
52 |
+
<Tip>
|
53 |
+
|
54 |
+
Having one model per GPU can lead to high memory usage, which may not be feasible for large models or low-memory GPUs. In such cases, you can leverage [DeepSpeed](https://github.com/deepspeedai/DeepSpeed), which provides optimizations like model sharding, Zero Redundancy Optimizer, mixed precision training, and offloading to CPU or NVMe. Check out our [DeepSpeed Integration](deepspeed_integration.md) guide for more details.
|
55 |
+
|
56 |
+
</Tip>
|
57 |
+
|
58 |
+
## Multi-Node Training
|
59 |
+
|
60 |
+
We're working on a guide for multi-node training. Stay tuned! 🚀
|
docs/source/dpo_trainer.md
ADDED
@@ -0,0 +1,279 @@
1 |
+
# DPO Trainer
|
2 |
+
|
3 |
+
[](https://huggingface.co/models?other=dpo,trl) [](https://github.com/huggingface/smol-course/tree/main/2_preference_alignment)
|
4 |
+
|
5 |
+
## Overview
|
6 |
+
|
7 |
+
TRL supports the DPO Trainer for training language models from preference data, as described in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://huggingface.co/papers/2305.18290) by [Rafael Rafailov](https://huggingface.co/rmrafailov), Archit Sharma, Eric Mitchell, [Stefano Ermon](https://huggingface.co/ermonste), [Christopher D. Manning](https://huggingface.co/manning), [Chelsea Finn](https://huggingface.co/cbfinn).
|
8 |
+
|
9 |
+
The abstract from the paper is the following:
|
10 |
+
|
11 |
+
> While large-scale unsupervised language models (LMs) learn broad world knowledge and some reasoning skills, achieving precise control of their behavior is difficult due to the completely unsupervised nature of their training. Existing methods for gaining such steerability collect human labels of the relative quality of model generations and fine-tune the unsupervised LM to align with these preferences, often with reinforcement learning from human feedback (RLHF). However, RLHF is a complex and often unstable procedure, first fitting a reward model that reflects the human preferences, and then fine-tuning the large unsupervised LM using reinforcement learning to maximize this estimated reward without drifting too far from the original model. In this paper we introduce a new parameterization of the reward model in RLHF that enables extraction of the corresponding optimal policy in closed form, allowing us to solve the standard RLHF problem with only a simple classification loss. The resulting algorithm, which we call Direct Preference Optimization (DPO), is stable, performant, and computationally lightweight, eliminating the need for sampling from the LM during fine-tuning or performing significant hyperparameter tuning. Our experiments show that DPO can fine-tune LMs to align with human preferences as well as or better than existing methods. Notably, fine-tuning with DPO exceeds PPO-based RLHF in ability to control sentiment of generations, and matches or improves response quality in summarization and single-turn dialogue while being substantially simpler to implement and train.
|
12 |
+
|
13 |
+
The first step is to train an SFT model, to ensure the data we train on is in-distribution for the DPO algorithm.
|
14 |
+
|
15 |
+
Then, fine-tuning a language model via DPO consists of two steps and is easier than [PPO](ppo_trainer):
|
16 |
+
|
17 |
+
1. **Data collection**: Gather a [preference dataset](dataset_formats#preference) with pairs of chosen (positive) and rejected (negative) generations for a given prompt.
|
18 |
+
2. **Optimization**: Maximize the log-likelihood of the DPO loss directly.
|
19 |
+
|
20 |
+
This process is illustrated in the sketch below (from [Figure 1 of the DPO paper](https://huggingface.co/papers/2305.18290)):
|
21 |
+
|
22 |
+

|
23 |
+
|
24 |
+
Read more about the DPO algorithm in the [original paper](https://huggingface.co/papers/2305.18290).
|
25 |
+
|
26 |
+
## Quick start
|
27 |
+
|
28 |
+
This example demonstrates how to train a model using the DPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here:
|
29 |
+
|
30 |
+
<iframe
|
31 |
+
src="https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized/embed/viewer/default/train?row=0"
|
32 |
+
frameborder="0"
|
33 |
+
width="100%"
|
34 |
+
height="560px"
|
35 |
+
></iframe>
|
36 |
+
|
37 |
+
Below is the script to train the model:
|
38 |
+
|
39 |
+
```python
|
40 |
+
# train_dpo.py
|
41 |
+
from datasets import load_dataset
|
42 |
+
from trl import DPOConfig, DPOTrainer
|
43 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
44 |
+
|
45 |
+
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
46 |
+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
47 |
+
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
48 |
+
|
49 |
+
training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", logging_steps=10)
|
50 |
+
trainer = DPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
|
51 |
+
trainer.train()
|
52 |
+
```
|
53 |
+
|
54 |
+
Execute the script using the following command:
|
55 |
+
|
56 |
+
```bash
|
57 |
+
accelerate launch train_dpo.py
|
58 |
+
```
|
59 |
+
|
60 |
+
Distributed across 8 GPUs, the training takes approximately 3 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.
|
61 |
+
|
62 |
+

|
63 |
+
|
64 |
+
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-DPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
|
65 |
+
|
66 |
+
<pre><code>$ transformers chat trl-lib/Qwen2-0.5B-DPO
|
67 |
+
<strong><span style="color: red;"><shirin_yamani>:</span></strong>
|
68 |
+
What is Huggingface?
|
69 |
+
|
70 |
+
<strong><span style="color: blue;"><trl-lib/Qwen2-0.5B-DPO>:</span></strong>
|
71 |
+
Huggingface is a platform that allows users to access a variety of open-source machine learning resources such as pre-trained models and datasets Huggingface is a platform that allows users to access a variety of open-source machine learning resources such as pre-trained models and datasets for the development of machine learning models and applications. It provides a repository of over 300, 000 pre-trained models in Huggingface is a platform that allows users to access a variety of open-source machine learning resources such as pre-trained models and datasets for the development of machine learning models and applications. It provides a repository of over 300, 000 pre-trained models in a variety of languages, enabling users to explore and utilize the latest techniques and technologies in the field of machine learning.
|
72 |
+
</code></pre>
|
73 |
+
|
74 |
+
## Expected dataset type
|
75 |
+
|
76 |
+
DPO requires a [preference dataset](dataset_formats#preference). The [`DPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
|
77 |
+
|
78 |
+
Although the [`DPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section.
|
79 |
+
|
80 |
+
### Special considerations for vision-language models
|
81 |
+
|
82 |
+
The [`DPOTrainer`] supports fine-tuning vision-language models (VLMs). For these models, a vision dataset is required. To learn more about the specific format for vision datasets, refer to the [Vision dataset format](dataset_formats#vision-datasets) section.
|
83 |
+
|
84 |
+
Additionally, unlike standard text-based models where a `tokenizer` is used, for VLMs, you should replace the `tokenizer` with a `processor`.
|
85 |
+
|
86 |
+
```diff
|
87 |
+
- model = AutoModelForCausalLM.from_pretrained(model_id)
|
88 |
+
+ model = AutoModelForVision2Seq.from_pretrained(model_id)
|
89 |
+
|
90 |
+
- tokenizer = AutoTokenizer.from_pretrained(model_id)
|
91 |
+
+ processor = AutoProcessor.from_pretrained(model_id)
|
92 |
+
|
93 |
+
trainer = DPOTrainer(
|
94 |
+
model,
|
95 |
+
args=training_args,
|
96 |
+
train_dataset=train_dataset,
|
97 |
+
- processing_class=tokenizer,
|
98 |
+
+ processing_class=processor,
|
99 |
+
)
|
100 |
+
```
|
101 |
+
|
102 |
+
For a complete example of fine-tuning a vision-language model, refer to the script in [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py).
|
103 |
+
|
104 |
+
|
105 |
+
## Example script
|
106 |
+
|
107 |
+
We provide an example script to train a model using the DPO method. The script is available in [`trl/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py)
|
108 |
+
|
109 |
+
To test the DPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command:
|
110 |
+
|
111 |
+
```bash
|
112 |
+
accelerate launch trl/scripts/dpo.py \
|
113 |
+
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
|
114 |
+
--dataset_name trl-lib/ultrafeedback_binarized \
|
115 |
+
--num_train_epochs 1 \
|
116 |
+
--logging_steps 25 \
|
117 |
+
--output_dir Qwen2-0.5B-DPO
|
118 |
+
```
|
119 |
+
|
120 |
+
## Logged metrics
|
121 |
+
|
122 |
+
While training and evaluating we record the following reward metrics:
|
123 |
+
|
124 |
+
- `rewards/chosen`: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses scaled by beta
|
125 |
+
- `rewards/rejected`: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses scaled by beta
|
126 |
+
- `rewards/accuracies`: mean of how often the chosen rewards are greater than the corresponding rejected rewards
|
127 |
+
- `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
|
128 |
+
|
129 |
+
## Loss functions
|
130 |
+
|
131 |
+
The DPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`DPOConfig`]. The following loss functions are supported:
|
132 |
+
|
133 |
+
| `loss_type=` | Description |
|
134 |
+
| -------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
135 |
+
| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model and in fact the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. |
|
136 |
+
| `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. |
|
137 |
+
| `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike DPO which is summed only). |
|
138 |
+
| `"exo_pair"` | The [EXO](https://huggingface.co/papers/2402.00856) authors propose to minimize the reverse KL instead of the negative log-sigmoid loss of DPO which corresponds to forward KL. Setting non-zero `label_smoothing` (default `1e-3`) leads to a simplified version of EXO on pair-wise preferences (see Eqn. (16) of the [EXO paper](https://huggingface.co/papers/2402.00856)). The full version of EXO uses `K>2` completions generated by the SFT policy, which becomes an unbiased estimator of the PPO objective (up to a constant) when `K` is sufficiently large. |
|
139 |
+
| `"nca_pair"` | The [NCA](https://huggingface.co/papers/2402.05369) authors shows that NCA optimizes the absolute likelihood for each response rather than the relative likelihood. |
|
140 |
+
| `"robust"` | The [Robust DPO](https://huggingface.co/papers/2403.00409) authors propose an unbiased estimate of the DPO loss that is robust to preference noise in the data. Like in cDPO, it assumes that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0) |
|
141 |
+
| `"bco_pair"` | The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. For unpaired data, we recommend the dedicated [`BCOTrainer`]. |
|
142 |
+
| `"sppo_hard"` | The [SPPO](https://huggingface.co/papers/2405.00675) authors claim that SPPO is capable of solving the Nash equilibrium iteratively by pushing the chosen rewards to be as large as 1/2 and the rejected rewards to be as small as -1/2 and can alleviate data sparsity issues. The implementation approximates this algorithm by employing hard label probabilities, assigning 1 to the winner and 0 to the loser. |
|
143 |
+
| `"aot"` or `loss_type="aot_pair"` | The [AOT](https://huggingface.co/papers/2406.05882) authors propose to use Distributional Preference Alignment Via Optimal Transport. Traditionally, the alignment algorithms use paired preferences at a sample level, which does not ensure alignment on the distributional level. AOT, on the other hand, can align LLMs on paired or unpaired preference data by making the reward distribution of the positive samples stochastically dominant in the first order on the distribution of negative samples. Specifically, `loss_type="aot"` is appropriate for paired datasets, where each prompt has both chosen and rejected responses; `loss_type="aot_pair"` is for unpaired datasets. In a nutshell, `loss_type="aot"` ensures that the log-likelihood ratio of chosen to rejected of the aligned model has higher quantiles than that ratio for the reference model. `loss_type="aot_pair"` ensures that the chosen reward is higher on all quantiles than the rejected reward. Note that in both cases quantiles are obtained via sorting. To fully leverage the advantages of the AOT algorithm, it is important to maximize the per-GPU batch size. |
|
144 |
+
| `"apo_zero"` or `loss_type="apo_down"` | The [APO](https://huggingface.co/papers/2408.06266) method introduces an "anchored" version of the alignment objective. There are two variants: `apo_zero` and `apo_down`. The `apo_zero` loss increases the likelihood of winning outputs while decreasing the likelihood of losing outputs, making it suitable when the model is less performant than the winning outputs. On the other hand, `apo_down` decreases the likelihood of both winning and losing outputs, but with a stronger emphasis on reducing the likelihood of losing outputs. This variant is more effective when the model is better than the winning outputs. |
|
145 |
+
| `"discopop"` | The [DiscoPOP](https://huggingface.co/papers/2406.08414) paper uses LLMs to discover more efficient offline preference optimization losses. In the paper the proposed DiscoPOP loss (which is a log-ratio modulated loss) outperformed other optimization losses on different tasks (IMDb positive text generation, Reddit TLDR summarization, and Alpaca Eval 2.0). |
|
146 |
+
|
147 |
+
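
To try one of these variants, pass the corresponding `loss_type` (together with any hyperparameter it relies on) to the [`DPOConfig`]. A minimal sketch with purely illustrative values:

```python
from trl import DPOConfig

# Illustrative only: pick any of the loss variants listed in the table above.
training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", loss_type="apo_zero")
```
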
### Label smoothing

[cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0).
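
For example, a minimal sketch with an illustrative noise level:

```python
from trl import DPOConfig

# Assume roughly 20% of the preference labels may be flipped (illustrative value).
training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", label_smoothing=0.2)
```
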
### Syncing the reference model

The [TR-DPO](https://huggingface.co/papers/2404.09656) paper suggests syncing the reference model weights after every `ref_model_sync_steps` steps of SGD with weight `ref_model_mixup_alpha` during DPO training. To enable this callback, set `sync_ref_model=True` in the [`DPOConfig`].
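
A minimal sketch; the values below are illustrative, see the TR-DPO paper for guidance on choosing them:

```python
from trl import DPOConfig

# Illustrative values: periodically mix the policy weights back into the reference model.
training_args = DPOConfig(
    output_dir="Qwen2-0.5B-DPO",
    sync_ref_model=True,
    ref_model_mixup_alpha=0.6,
    ref_model_sync_steps=64,
)
```
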
### RPO loss

The [RPO](https://huggingface.co/papers/2404.19733) paper implements an iterative preference tuning algorithm using a loss related to the RPO loss in this [paper](https://huggingface.co/papers/2405.16436) that essentially consists of a weighted SFT loss on the chosen preferences together with the DPO loss. To use this loss, set the `rpo_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this weight to `1.0`.
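
For example:

```python
from trl import DPOConfig

# Weight of the SFT term; `1.0` is the value suggested by the paper.
training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", rpo_alpha=1.0)
```
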
### WPO loss

The [WPO](https://huggingface.co/papers/2406.11827) paper adapts off-policy data to resemble on-policy data more closely by reweighting preference pairs according to their probability under the current policy. To use this method, set the `use_weighting` flag to `True` in the [`DPOConfig`].
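
For example:

```python
from trl import DPOConfig

# Reweight preference pairs by their probability under the current policy.
training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", use_weighting=True)
```
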
### LD-DPO loss

The [LD-DPO](https://huggingface.co/papers/2409.06411) paper decomposes the portion of the response that exceeds the desired length into two components, human-like preferences and verbosity preference, based on a mixing coefficient \\( \alpha \\). To use this method, set the `ld_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this value between `0.0` and `1.0`.
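
For example, with an illustrative value inside the suggested range:

```python
from trl import DPOConfig

# Illustrative: down-weight the verbosity component of overlong responses.
training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", ld_alpha=0.5)
```
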
### For Mixture of Experts Models: Enabling the auxiliary loss

MOEs are most efficient when the load is roughly equally distributed between experts.
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.

This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
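
A minimal sketch of what this could look like for a Mixtral-style model (the coefficient shown is the default mentioned above):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
# Return the router logits so that the load-balancing auxiliary loss is added to the training loss.
model.config.output_router_logits = True
# Scale the contribution of the auxiliary loss to the total loss (0.001 is the default).
model.config.router_aux_loss_coef = 0.001
```
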
## Accelerate DPO fine-tuning using `unsloth`

You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library, which is fully compatible with TRL trainers such as [`DPOTrainer`]. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek, etc.) and Mistral architectures. Some benchmarks for DPO are listed below:

| GPU      | Model     | Dataset    | 🤗  | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved |
| -------- | --------- | ---------- | --- | ---------------------- | ---------- | ------------- |
| A100 40G | Zephyr 7b | Ultra Chat | 1x  | 1.24x                  | **1.88x**  | -11.6%        |
| Tesla T4 | Zephyr 7b | Ultra Chat | 1x  | 1.09x                  | **1.55x**  | -18.6%        |

First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:

```diff
  from datasets import load_dataset
  from trl import DPOConfig, DPOTrainer
- from transformers import AutoModelForCausalLM, AutoTokenizer
+ from unsloth import FastLanguageModel

- model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
- tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+ model, tokenizer = FastLanguageModel.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+ model = FastLanguageModel.get_peft_model(model)
  train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

- training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", logging_steps=10)
+ training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", logging_steps=10, bf16=True)
  trainer = DPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
  trainer.train()
```

The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth).

## Reference model considerations with PEFT

You have three main options (plus several variants) for how the reference model works when using PEFT, assuming the model that you would like to further enhance with DPO was tuned using (Q)LoRA.

1. Simply create two instances of the model, each loading your adapter. This works fine but is very inefficient.
2. Merge the adapter into the base model, create another adapter on top, then leave the `ref_model` param null, in which case the [`DPOTrainer`] will unload the adapter for reference inference. This is efficient but has potential downsides discussed below.
3. Load the adapter twice with different names, then use `set_adapter` during training to swap between the adapter being DPO'd and the reference adapter. This is slightly less efficient than option 2 (~adapter size VRAM overhead), but avoids the pitfalls.

### Downsides to merging QLoRA before DPO (approach 2)

As suggested by [Benjamin Marie](https://medium.com/@bnjmn_marie/dont-merge-your-lora-adapter-into-a-4-bit-llm-65b6da287997), the best option for merging QLoRA adapters is to first dequantize the base model and then merge the adapter, similar to what is done in [this script](https://github.com/jondurbin/qlora/blob/main/qmerge.py).

However, after using this approach, you will have an unquantized base model. Therefore, to use QLoRA for DPO, you will need to re-quantize the merged model or use the unquantized merge (resulting in higher memory demand).

### Using option 3 - load the adapter twice

To avoid the downsides of option 2, you can load your fine-tuned adapter into the model twice, with different names, and set the model/ref adapter names in the [`DPOTrainer`].

For example:

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from trl import DPOConfig, DPOTrainer

# Load the base model in 4-bit.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/mixtral-8x7b-v0.1",
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.config.use_cache = False

# Load the adapter.
model = PeftModel.from_pretrained(
    model,
    "/path/to/peft",
    is_trainable=True,
    adapter_name="train",
)
# Load the adapter a second time, with a different name, which will be our reference model.
model.load_adapter("/path/to/peft", adapter_name="reference")

# Initialize the trainer, without a ref_model param.
training_args = DPOConfig(
    model_adapter_name="train",
    ref_adapter_name="reference",
)
dpo_trainer = DPOTrainer(
    model,
    args=training_args,
    ...
)
```

## DPOTrainer

[[autodoc]] DPOTrainer

## DPOConfig

[[autodoc]] DPOConfig

## DataCollatorForPreference

[[autodoc]] trainer.dpo_trainer.DataCollatorForPreference

docs/source/example_overview.md
ADDED
@@ -0,0 +1,89 @@

# Examples

## Introduction

The examples should work in any of the following settings (with the same script):
- single GPU
- multiple GPUs (using PyTorch distributed mode)
- multiple GPUs (using DeepSpeed ZeRO-Offload stages 1, 2, & 3)
- fp16 (mixed-precision), fp32 (normal precision), or bf16 (bfloat16 precision)

To run the examples in any of these modes, first initialize the 🤗 Accelerate configuration with `accelerate config`.

**Note:** to train with a 4-bit or 8-bit model, please run

```bash
pip install --upgrade trl[quantization]
```

## Accelerate Config

For all the examples, you'll need to generate a 🤗 Accelerate config file with:

```shell
accelerate config # will prompt you to define the training configuration
```

Then launch your jobs with `accelerate launch`!

# Maintained Examples

Scripts can be used as examples of how to use TRL trainers. They are located in the [`trl/scripts`](https://github.com/huggingface/trl/blob/main/trl/scripts) directory. Additionally, we provide examples in the [`examples/scripts`](https://github.com/huggingface/trl/blob/main/examples/scripts) directory. These examples are maintained and tested regularly.

| File | Description |
| --- | --- |
| [`examples/scripts/alignprop.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/alignprop.py) | This script shows how to use the [`AlignPropTrainer`] to fine-tune a diffusion model. |
| [`examples/scripts/bco.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/bco.py) | This script shows how to use the [`KTOTrainer`] with the BCO loss to fine-tune a model to increase instruction-following, truthfulness, honesty and helpfulness using the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. |
| [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
| [`examples/scripts/ddpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ddpo.py) | This script shows how to use the [`DDPOTrainer`] to fine-tune a stable diffusion model using reinforcement learning. |
| [`examples/scripts/dpo_online.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_online.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a model. |
| [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a Vision Language Model to reduce hallucinations using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) dataset. |
| [`examples/scripts/gkd.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gkd.py) | This script shows how to use the [`GKDTrainer`] to fine-tune a model. |
| [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) | This script shows how to use the [`NashMDTrainer`] to fine-tune a model. |
| [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
| [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language. |
| [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
| [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py) | This script shows how to use the [`PRMTrainer`] to fine-tune a Process-supervised Reward Model (PRM). |
| [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train an Outcome Reward Model (ORM) on your own dataset. |
| [`examples/scripts/rloo/rloo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo/rloo.py) | This script shows how to use the [`RLOOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language. |
| [`examples/scripts/rloo/rloo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo/rloo_tldr.py) | This script shows how to use the [`RLOOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
| [`examples/scripts/sft_gemma3.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_gemma3.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Gemma 3 model. |
| [`examples/scripts/sft_video_llm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_video_llm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Video Language Model. |
| [`examples/scripts/sft_vlm_gemma3.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_gemma3.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Gemma 3 model on vision-to-text tasks. |
| [`examples/scripts/sft_vlm_smol_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_smol_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a SmolVLM model. |
| [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Vision Language Model in a chat setting. The script has only been tested with [LLaVA 1.5](https://huggingface.co/llava-hf/llava-1.5-7b-hf), [LLaVA 1.6](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf), and [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) models, so users may see unexpected behaviour with other model architectures. |
| [`examples/scripts/xpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/xpo.py) | This script shows how to use the [`XPOTrainer`] to fine-tune a model. |

Here are also some easier-to-run Colab notebooks that you can use to get started with TRL:

| File | Description |
| --- | --- |
| [`examples/notebooks/best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb) | This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO. |
| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 IMDb sentiment tuning example in a Jupyter notebook. |
| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example in a Jupyter notebook. |

We also have some other examples that are less maintained but can be used as a reference:
1. **[research_projects](https://github.com/huggingface/trl/tree/main/examples/research_projects)**: Check out this folder to find the scripts used for some research projects that used TRL (LM de-toxification, Stack-Llama, etc.)

## Distributed training

All of the scripts can be run on multiple GPUs by providing the path of an 🤗 Accelerate config file when calling `accelerate launch`. To launch one of them on one or multiple GPUs, run the following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine and `--all_arguments_of_the_script` with your arguments):

```shell
accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script
```

You can also adjust the parameters of the 🤗 Accelerate config file to suit your needs (e.g. training in mixed precision).

### Distributed training with DeepSpeed

Most of the scripts can be run on multiple GPUs together with DeepSpeed ZeRO-{1,2,3} for efficient sharding of the optimizer states, gradients, and model weights. To do so, run the following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine, `--all_arguments_of_the_script` with your arguments, and choosing the DeepSpeed config that matches your ZeRO stage, e.g. `examples/accelerate_configs/deepspeed_zero1.yaml`):

```shell
accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{1,2,3}.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script
```