ivangabriele committed
Commit 2f5127c · verified · 0 Parent(s)

feat: initialize project

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
Files changed (50)
  1. .gitattributes +47 -0
  2. .github/ISSUE_TEMPLATE/bug-report.yml +67 -0
  3. .github/ISSUE_TEMPLATE/feature-request.yml +31 -0
  4. .github/ISSUE_TEMPLATE/new-trainer-addition.yml +32 -0
  5. .github/PULL_REQUEST_TEMPLATE.md +31 -0
  6. .github/codeql/custom-queries.qls +19 -0
  7. .github/workflows/build_documentation.yml +19 -0
  8. .github/workflows/build_pr_documentation.yml +19 -0
  9. .github/workflows/clear_cache.yml +33 -0
  10. .github/workflows/codeQL.yml +26 -0
  11. .github/workflows/docker-build.yml +95 -0
  12. .github/workflows/issue_auto_labeller.yml +15 -0
  13. .github/workflows/pr_style_bot.yml +127 -0
  14. .github/workflows/slow-tests.yml +98 -0
  15. .github/workflows/tests.yml +252 -0
  16. .github/workflows/tests_latest.yml +66 -0
  17. .github/workflows/trufflehog.yml +18 -0
  18. .github/workflows/upload_pr_documentation.yml +16 -0
  19. .gitignore +144 -0
  20. .pre-commit-config.yaml +17 -0
  21. CITATION.cff +34 -0
  22. CODE_OF_CONDUCT.md +133 -0
  23. CONTRIBUTING.md +767 -0
  24. Dockerfile +37 -0
  25. LICENSE +201 -0
  26. MANIFEST.in +6 -0
  27. Makefile +29 -0
  28. README.md +210 -0
  29. commands/run_dpo.sh +58 -0
  30. commands/run_sft.sh +59 -0
  31. docker-compose.yml +5 -0
  32. docker/trl-latest-gpu/Dockerfile +66 -0
  33. docker/trl-source-gpu/Dockerfile +66 -0
  34. docs/source/_toctree.yml +116 -0
  35. docs/source/alignprop_trainer.md +93 -0
  36. docs/source/bco_trainer.md +100 -0
  37. docs/source/best_of_n.md +72 -0
  38. docs/source/callbacks.md +21 -0
  39. docs/source/clis.md +272 -0
  40. docs/source/community_tutorials.md +32 -0
  41. docs/source/cpo_trainer.md +108 -0
  42. docs/source/customization.md +121 -0
  43. docs/source/data_utils.md +41 -0
  44. docs/source/dataset_formats.md +938 -0
  45. docs/source/ddpo_trainer.md +131 -0
  46. docs/source/deepspeed_integration.md +39 -0
  47. docs/source/detoxifying_a_lm.md +187 -0
  48. docs/source/distributing_training.md +60 -0
  49. docs/source/dpo_trainer.md +279 -0
  50. docs/source/example_overview.md +89 -0
.gitattributes ADDED
@@ -0,0 +1,47 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.otf filter=lfs diff=lfs merge=lfs -text
37
+ *.eot filter=lfs diff=lfs merge=lfs -text
38
+ *.ttf filter=lfs diff=lfs merge=lfs -text
39
+ *.png filter=lfs diff=lfs merge=lfs -text
40
+ **/*.otf filter=lfs diff=lfs merge=lfs -text
41
+ **/*.eot filter=lfs diff=lfs merge=lfs -text
42
+ **/*.ttf filter=lfs diff=lfs merge=lfs -text
43
+ **/*.png filter=lfs diff=lfs merge=lfs -text
44
+ docs/**/*.otf filter=lfs diff=lfs merge=lfs -text
45
+ docs/**/*.eot filter=lfs diff=lfs merge=lfs -text
46
+ docs/**/*.ttf filter=lfs diff=lfs merge=lfs -text
47
+ docs/**/*.png filter=lfs diff=lfs merge=lfs -text
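The rules above route large binary and model files through Git LFS by assigning them the `lfs` filter, diff, and merge drivers. A quick way to confirm that a path is covered is `git check-attr`; the sketch below assumes `git` and `git-lfs` are installed and you are inside a clone, and uses `model.safetensors` / `*.gguf` purely as illustrative names:

```bash
# Show which attributes apply to a hypothetical weights file;
# "filter: lfs" in the output means the rules above catch it.
git check-attr filter diff merge -- model.safetensors

# Track one more pattern the same way (appends a matching line to .gitattributes).
git lfs track "*.gguf"
```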
.github/ISSUE_TEMPLATE/bug-report.yml ADDED
@@ -0,0 +1,67 @@
1
+ name: "\U0001F41B Bug Report"
2
+ description: Submit a bug report to help us improve TRL
3
+ labels: [ "bug" ]
4
+ body:
5
+ - type: markdown
6
+ attributes:
7
+ value: |
8
+ Thanks for taking the time to fill out this bug report! 🤗
9
+
10
+ 🚩 If it is your first time submitting, be sure to check our [bug report guidelines](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#did-you-find-a-bug)
11
+
12
+ - type: textarea
13
+ id: reproduction
14
+ validations:
15
+ required: true
16
+ attributes:
17
+ label: Reproduction
18
+ description: |
19
+ Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet.
20
+ If you have code snippets, error messages, or stack traces, please provide them here as well.
21
+ Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
22
+ Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
23
+
24
+ value: |
25
+ ```python
26
+ from trl import ...
27
+
28
+ ```
29
+
30
+ outputs:
31
+
32
+ ```
33
+ Traceback (most recent call last):
34
+ File "example.py", line 42, in <module>
35
+ ...
36
+ ```
37
+
38
+ - type: textarea
39
+ id: system-info
40
+ attributes:
41
+ label: System Info
42
+ description: |
43
+ Please provide information about your system: platform, Python version, PyTorch version, Transformers version, devices, TRL version, ...
44
+ You can get this information by running `trl env` in your terminal.
45
+
46
+ placeholder: Copy-paste the output of `trl env`
47
+ validations:
48
+ required: true
49
+
50
+ - type: checkboxes
51
+ id: terms
52
+ attributes:
53
+ label: Checklist
54
+ description: |
55
+ Before submitting, please confirm that you've completed each of the following.
56
+ If an item doesn't apply to your issue, check it anyway to show you've reviewed it.
57
+ options:
58
+ - label: "I have checked that my issue isn't already filed (see [open issues](https://github.com/huggingface/trl/issues?q=is%3Aissue))"
59
+ required: true
60
+ - label: "I have included my system information"
61
+ required: true
62
+ - label: "Any code provided is minimal, complete, and reproducible ([more on MREs](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))"
63
+ required: true
64
+ - label: "Any code provided is properly formatted in code blocks, (no screenshot, [more on code blocks](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))"
65
+ required: true
66
+ - label: "Any traceback provided is complete"
67
+ required: true
.github/ISSUE_TEMPLATE/feature-request.yml ADDED
@@ -0,0 +1,31 @@
1
+ name: "\U0001F680 Feature request"
2
+ description: Submit a proposal/request for a new TRL feature
3
+ labels: [ "Feature request" ]
4
+ body:
5
+ - type: textarea
6
+ id: feature-request
7
+ validations:
8
+ required: true
9
+ attributes:
10
+ label: Feature request
11
+ description: |
12
+ A clear and concise description of the feature proposal. Please provide a link to the paper and code in case they exist.
13
+
14
+ - type: textarea
15
+ id: motivation
16
+ validations:
17
+ required: true
18
+ attributes:
19
+ label: Motivation
20
+ description: |
21
+ Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too.
22
+
23
+
24
+ - type: textarea
25
+ id: contribution
26
+ validations:
27
+ required: true
28
+ attributes:
29
+ label: Your contribution
30
+ description: |
31
+ Is there any way that you could help, e.g. by submitting a PR? Make sure to read the CONTRIBUTING.MD [readme](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md)
.github/ISSUE_TEMPLATE/new-trainer-addition.yml ADDED
@@ -0,0 +1,32 @@
1
+ name: "\U0001F31F New trainer addition"
2
+ description: Submit a proposal/request to implement a new trainer for a post-training method
3
+ labels: [ "New trainer" ]
4
+
5
+ body:
6
+ - type: textarea
7
+ id: description-request
8
+ validations:
9
+ required: true
10
+ attributes:
11
+ label: Method description
12
+ description: |
13
+ Put any and all important information relative to the method
14
+
15
+ - type: checkboxes
16
+ id: information-tasks
17
+ attributes:
18
+ label: Open source status
19
+ description: |
20
+ Please note that if the method implementation isn't available or model weights with training datasets aren't available, we are less likely to implement it in `trl`.
21
+ options:
22
+ - label: "The method implementation is available"
23
+ - label: "The model weights are available"
24
+ - label: "The training datasets are available"
25
+
26
+ - type: textarea
27
+ id: additional-info
28
+ attributes:
29
+ label: Provide useful links for the implementation
30
+ description: |
31
+ Please provide information regarding the implementation, the weights, and the authors.
32
+ Please mention the authors by @gh-username if you're aware of their usernames.
.github/PULL_REQUEST_TEMPLATE.md ADDED
@@ -0,0 +1,31 @@
1
+ # What does this PR do?
2
+
3
+ <!--
4
+ Congratulations! You've made it this far! You're not quite done yet though.
5
+
6
+ Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution.
7
+
8
+ Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change.
9
+
10
+ Once you're done, someone will review your PR shortly. They may suggest changes to make the code even better.
11
+ -->
12
+
13
+ <!-- Remove if not applicable -->
14
+
15
+ Fixes # (issue)
16
+
17
+
18
+ ## Before submitting
19
+ - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
20
+ - [ ] Did you read the [contributor guideline](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#create-a-pull-request),
21
+ Pull Request section?
22
+ - [ ] Was this discussed/approved via a GitHub issue? Please add a link
23
+ to it if that's the case.
24
+ - [ ] Did you make sure to update the documentation with your changes?
25
+ - [ ] Did you write any new necessary tests?
26
+
27
+
28
+ ## Who can review?
29
+
30
+ Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
31
+ members/contributors who may be interested in your PR.
.github/codeql/custom-queries.qls ADDED
@@ -0,0 +1,19 @@
1
+ import codeql
2
+
3
+ from WorkflowString interpolation, Workflow workflow
4
+ where
5
+ interpolation.getStringValue().matches("${{ github.event.issue.title }}") or
6
+ interpolation.getStringValue().matches("${{ github.event.issue.body }}") or
7
+ interpolation.getStringValue().matches("${{ github.event.pull_request.title }}") or
8
+ interpolation.getStringValue().matches("${{ github.event.pull_request.body }}") or
9
+ interpolation.getStringValue().matches("${{ github.event.review.body }}") or
10
+ interpolation.getStringValue().matches("${{ github.event.comment.body }}") or
11
+ interpolation.getStringValue().matches("${{ github.event.inputs.* }}") or
12
+ interpolation.getStringValue().matches("${{ github.event.head_commit.message }}")
13
+ interpolation.getStringValue().matches("${{ github.event.* }}") and
14
+ (
15
+ step.getKey() = "run" or // Injection in run
16
+ step.getKey() = "env" or // Injection via env
17
+ step.getKey() = "with" // Injection via with
18
+ )
19
+ select workflow, "🚨 Do not use directly as input of action"
.github/workflows/build_documentation.yml ADDED
@@ -0,0 +1,19 @@
1
+ name: Build documentation
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+ - doc-builder*
8
+ - v*-release
9
+
10
+ jobs:
11
+ build:
12
+ uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
13
+ with:
14
+ commit_sha: ${{ github.sha }}
15
+ package: trl
16
+ version_tag_suffix: ""
17
+ custom_container: huggingface/transformers-doc-builder
18
+ secrets:
19
+ hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
.github/workflows/build_pr_documentation.yml ADDED
@@ -0,0 +1,19 @@
1
+ name: Build PR Documentation
2
+
3
+ on:
4
+ pull_request:
5
+
6
+ concurrency:
7
+ group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
8
+ cancel-in-progress: true
9
+
10
+ jobs:
11
+ build:
12
+ if: github.event.pull_request.draft == false
13
+ uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
14
+ with:
15
+ commit_sha: ${{ github.event.pull_request.head.sha }}
16
+ pr_number: ${{ github.event.number }}
17
+ package: trl
18
+ version_tag_suffix: ""
19
+ custom_container: huggingface/transformers-doc-builder
.github/workflows/clear_cache.yml ADDED
@@ -0,0 +1,33 @@
1
+ name: "Cleanup Cache"
2
+
3
+ on:
4
+ workflow_dispatch:
5
+ schedule:
6
+ - cron: "0 0 * * *"
7
+
8
+ jobs:
9
+ cleanup:
10
+ runs-on: ubuntu-latest
11
+ steps:
12
+ - name: Check out code
13
+ uses: actions/checkout@v4
14
+
15
+ - name: Cleanup
16
+ run: |
17
+ gh extension install actions/gh-actions-cache
18
+
19
+ REPO=${{ github.repository }}
20
+
21
+ echo "Fetching list of cache key"
22
+ cacheKeysForPR=$(gh actions-cache list -R $REPO | cut -f 1 )
23
+
24
+ ## Setting this to not fail the workflow while deleting cache keys.
25
+ set +e
26
+ echo "Deleting caches..."
27
+ for cacheKey in $cacheKeysForPR
28
+ do
29
+ gh actions-cache delete $cacheKey -R $REPO --confirm
30
+ done
31
+ echo "Done"
32
+ env:
33
+ GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
.github/workflows/codeQL.yml ADDED
@@ -0,0 +1,26 @@
1
+ name: "CodeQL Analysis - Workflows"
2
+
3
+ on:
4
+ workflow_dispatch:
5
+
6
+ jobs:
7
+ analyze:
8
+ name: "Analyze GitHub Workflows"
9
+ runs-on: ubuntu-latest
10
+ permissions:
11
+ security-events: write
12
+ actions: read
13
+ contents: read
14
+
15
+ steps:
16
+ - name: "Checkout repository"
17
+ uses: actions/checkout@v4
18
+
19
+ - name: "Initialize CodeQL"
20
+ uses: github/codeql-action/init@v2
21
+ with:
22
+ languages: "yaml"
23
+ queries: +security-and-quality, ./.github/codeql/custom-queries.qls
24
+
25
+ - name: "Perform CodeQL Analysis"
26
+ uses: github/codeql-action/analyze@v2
.github/workflows/docker-build.yml ADDED
@@ -0,0 +1,95 @@
1
+ name: Build Docker images (scheduled)
2
+
3
+ on:
4
+ workflow_dispatch:
5
+ workflow_call:
6
+ schedule:
7
+ - cron: "0 1 * * *"
8
+
9
+ concurrency:
10
+ group: docker-image-builds
11
+ cancel-in-progress: false
12
+
13
+ env:
14
+ CI_SLACK_CHANNEL: ${{ secrets.CI_DOCKER_CHANNEL }}
15
+
16
+ jobs:
17
+ trl-latest:
18
+ name: "Latest TRL GPU"
19
+ runs-on: ubuntu-latest
20
+ steps:
21
+ - name: Cleanup disk
22
+ run: |
23
+ sudo ls -l /usr/local/lib/
24
+ sudo ls -l /usr/share/
25
+ sudo du -sh /usr/local/lib/
26
+ sudo du -sh /usr/share/
27
+ sudo rm -rf /usr/local/lib/android
28
+ sudo rm -rf /usr/share/dotnet
29
+ sudo du -sh /usr/local/lib/
30
+ sudo du -sh /usr/share/
31
+ - name: Set up Docker Buildx
32
+ uses: docker/setup-buildx-action@v1
33
+ - name: Check out code
34
+ uses: actions/checkout@v4
35
+ - name: Login to DockerHub
36
+ uses: docker/login-action@v1
37
+ with:
38
+ username: ${{ secrets.DOCKERHUB_USERNAME }}
39
+ password: ${{ secrets.DOCKERHUB_PASSWORD }}
40
+
41
+ - name: Build and Push GPU
42
+ uses: docker/build-push-action@v4
43
+ with:
44
+ context: ./docker/trl-latest-gpu
45
+ push: true
46
+ tags: huggingface/trl-latest-gpu
47
+
48
+ - name: Post to Slack
49
+ if: always()
50
+ uses: huggingface/hf-workflows/.github/actions/post-slack@main
51
+ with:
52
+ slack_channel: ${{ env.CI_SLACK_CHANNEL }}
53
+ title: 🤗 Results of the trl-latest-gpu Docker Image build
54
+ status: ${{ job.status }}
55
+ slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
56
+
57
+ trl-source:
58
+ name: "Latest TRL + HF ecosystem from source"
59
+ runs-on: ubuntu-latest
60
+ steps:
61
+ - name: Cleanup disk
62
+ run: |
63
+ sudo ls -l /usr/local/lib/
64
+ sudo ls -l /usr/share/
65
+ sudo du -sh /usr/local/lib/
66
+ sudo du -sh /usr/share/
67
+ sudo rm -rf /usr/local/lib/android
68
+ sudo rm -rf /usr/share/dotnet
69
+ sudo du -sh /usr/local/lib/
70
+ sudo du -sh /usr/share/
71
+ - name: Set up Docker Buildx
72
+ uses: docker/setup-buildx-action@v1
73
+ - name: Check out code
74
+ uses: actions/checkout@v4
75
+ - name: Login to DockerHub
76
+ uses: docker/login-action@v1
77
+ with:
78
+ username: ${{ secrets.DOCKERHUB_USERNAME }}
79
+ password: ${{ secrets.DOCKERHUB_PASSWORD }}
80
+
81
+ - name: Build and Push GPU
82
+ uses: docker/build-push-action@v4
83
+ with:
84
+ context: ./docker/trl-source-gpu
85
+ push: true
86
+ tags: huggingface/trl-source-gpu
87
+
88
+ - name: Post to Slack
89
+ if: always()
90
+ uses: huggingface/hf-workflows/.github/actions/post-slack@main
91
+ with:
92
+ slack_channel: ${{ env.CI_SLACK_CHANNEL }}
93
+ title: 🤗 Results of the trl-source-gpu Docker Image build
94
+ status: ${{ job.status }}
95
+ slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
.github/workflows/issue_auto_labeller.yml ADDED
@@ -0,0 +1,15 @@
1
+ name: "Hugging Face Issue Labeler"
2
+ on:
3
+ issues:
4
+ types: opened
5
+
6
+ jobs:
7
+ triage:
8
+ runs-on: ubuntu-latest
9
+ permissions:
10
+ issues: write
11
+ steps:
12
+ - uses: actions/checkout@v3
13
+ - uses: August-murr/auto-labeler@main
14
+ with:
15
+ hf-api-key: ${{ secrets.CI_HF_API_TOKEN }}
.github/workflows/pr_style_bot.yml ADDED
@@ -0,0 +1,127 @@
1
+ name: PR Style Bot
2
+
3
+ on:
4
+ workflow_dispatch:
5
+
6
+
7
+ permissions:
8
+ contents: write
9
+ pull-requests: write
10
+
11
+ jobs:
12
+ run-style-bot:
13
+ if: >
14
+ contains(github.event.comment.body, '@bot /style') &&
15
+ github.event.issue.pull_request != null
16
+ runs-on: ubuntu-latest
17
+
18
+ steps:
19
+ - name: Extract PR details
20
+ id: pr_info
21
+ uses: actions/github-script@v6
22
+ with:
23
+ script: |
24
+ const prNumber = context.payload.issue.number;
25
+ const { data: pr } = await github.rest.pulls.get({
26
+ owner: context.repo.owner,
27
+ repo: context.repo.repo,
28
+ pull_number: prNumber
29
+ });
30
+
31
+ // We capture both the branch ref and the "full_name" of the head repo
32
+ // so that we can check out the correct repository & branch (including forks).
33
+ core.setOutput("prNumber", prNumber);
34
+ core.setOutput("headRef", pr.head.ref);
35
+ core.setOutput("headRepoFullName", pr.head.repo.full_name);
36
+
37
+ - name: Check out PR branch
38
+ uses: actions/checkout@v3
39
+ env:
40
+ HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
41
+ HEADREF: ${{ steps.pr_info.outputs.headRef }}
42
+ with:
43
+ # Instead of checking out the base repo, use the contributor's repo name
44
+ repository: ${{ env.HEADREPOFULLNAME }}
45
+ ref: ${{ env.HEADREF }}
46
+ # You may need fetch-depth: 0 for being able to push
47
+ fetch-depth: 0
48
+ token: ${{ secrets.GITHUB_TOKEN }}
49
+
50
+ - name: Debug
51
+ env:
52
+ HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
53
+ HEADREF: ${{ steps.pr_info.outputs.headRef }}
54
+ PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
55
+ run: |
56
+ echo "PR number: ${{ env.PRNUMBER }}"
57
+ echo "Head Ref: ${{ env.HEADREF }}"
58
+ echo "Head Repo Full Name: ${{ env.HEADREPOFULLNAME }}"
59
+
60
+ - name: Set up Python
61
+ uses: actions/setup-python@v4
62
+
63
+ - name: Install dependencies
64
+ run: |
65
+ pip install ruff pre-commit
66
+
67
+ - name: Download Makefile from main branch
68
+ run: |
69
+ curl -o main_Makefile https://raw.githubusercontent.com/huggingface/trl/main/Makefile
70
+
71
+ - name: Compare Makefiles
72
+ run: |
73
+ if ! diff -q main_Makefile Makefile; then
74
+ echo "Error: The Makefile has changed. Please ensure it matches the main branch."
75
+ exit 1
76
+ fi
77
+ echo "No changes in Makefile. Proceeding..."
78
+ rm -rf main_Makefile
79
+
80
+ - name: Run make style and make quality
81
+ run: |
82
+ make precommit || true
83
+
84
+ - name: Commit and push changes
85
+ id: commit_and_push
86
+ env:
87
+ HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
88
+ HEADREF: ${{ steps.pr_info.outputs.headRef }}
89
+ PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
90
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
91
+ run: |
92
+ echo "HEADREPOFULLNAME: ${{ env.HEADREPOFULLNAME }}, HEADREF: ${{ env.HEADREF }}"
93
+ # Configure git with the Actions bot user
94
+ git config user.name "github-actions[bot]"
95
+ git config user.email "github-actions[bot]@users.noreply.github.com"
96
+
97
+ # Make sure your 'origin' remote is set to the contributor's fork
98
+ git remote set-url origin "https://x-access-token:${GITHUB_TOKEN}@github.com/${{ env.HEADREPOFULLNAME }}.git"
99
+
100
+ # If there are changes after running style/quality, commit them
101
+ if [ -n "$(git status --porcelain)" ]; then
102
+ git add .
103
+ git commit -m "Apply style fixes"
104
+ # Push to the original contributor's forked branch
105
+ git push origin HEAD:${{ env.HEADREF }}
106
+ echo "changes_pushed=true" >> $GITHUB_OUTPUT
107
+ else
108
+ echo "No changes to commit."
109
+ echo "changes_pushed=false" >> $GITHUB_OUTPUT
110
+ fi
111
+
112
+ - name: Comment on PR with workflow run link
113
+ if: steps.commit_and_push.outputs.changes_pushed == 'true'
114
+ uses: actions/github-script@v6
115
+ with:
116
+ script: |
117
+ const prNumber = parseInt(process.env.prNumber, 10);
118
+ const runUrl = `${process.env.GITHUB_SERVER_URL}/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}`
119
+
120
+ await github.rest.issues.createComment({
121
+ owner: context.repo.owner,
122
+ repo: context.repo.repo,
123
+ issue_number: prNumber,
124
+ body: `Style fixes have been applied. [View the workflow run here](${runUrl}).`
125
+ });
126
+ env:
127
+ prNumber: ${{ steps.pr_info.outputs.prNumber }}
.github/workflows/slow-tests.yml ADDED
@@ -0,0 +1,98 @@
1
+ name: Slow tests (on push)
2
+
3
+ on:
4
+ push:
5
+ branches: [ main ]
6
+ paths:
7
+ # Run only when python files are modified
8
+ - "trl/**.py"
9
+ - "examples/**.py"
10
+ env:
11
+ RUN_SLOW: "yes"
12
+ IS_GITHUB_CI: "1"
13
+ SLACK_API_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
14
+
15
+
16
+ jobs:
17
+ run_all_tests_single_gpu:
18
+ strategy:
19
+ fail-fast: false
20
+ matrix:
21
+ docker-image-name: ["huggingface/trl-latest-gpu:latest", "huggingface/trl-source-gpu:latest"]
22
+ runs-on:
23
+ group: aws-g4dn-2xlarge
24
+ env:
25
+ CUDA_VISIBLE_DEVICES: "0"
26
+ TEST_TYPE: "single_gpu_${{ matrix.docker-image-name }}"
27
+ container:
28
+ image: ${{ matrix.docker-image-name }}
29
+ options: --gpus all --shm-size "16gb" -e NVIDIA_DISABLE_REQUIRE=true
30
+ defaults:
31
+ run:
32
+ shell: bash
33
+ steps:
34
+ - uses: actions/checkout@v4
35
+ - name: Pip install
36
+ run: |
37
+ source activate trl
38
+ pip install -e ".[test]" --no-deps
39
+ pip install pytest-reportlog parameterized
40
+
41
+ - name: Run slow SFT tests on single GPU
42
+ if: always()
43
+ run: |
44
+ source activate trl
45
+ make slow_tests
46
+
47
+ - name: Generate Report
48
+ if: always()
49
+ run: |
50
+ pip install slack_sdk tabulate
51
+ python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
52
+
53
+
54
+ run_all_tests_multi_gpu:
55
+ strategy:
56
+ fail-fast: false
57
+ matrix:
58
+ docker-image-name: ["huggingface/trl-latest-gpu:latest", "huggingface/trl-source-gpu:latest"]
59
+ runs-on:
60
+ group: aws-g4dn-2xlarge
61
+ env:
62
+ CUDA_VISIBLE_DEVICES: "0,1"
63
+ TEST_TYPE: "multi_gpu_${{ matrix.docker-image-name }}"
64
+ container:
65
+ image: ${{ matrix.docker-image-name }}
66
+ options: --gpus all --shm-size "16gb" -e NVIDIA_DISABLE_REQUIRE=true
67
+ defaults:
68
+ run:
69
+ shell: bash
70
+ steps:
71
+ - uses: actions/checkout@v4
72
+ - name: Pip install
73
+ run: |
74
+ source activate trl
75
+ pip install -e ".[test]" --no-deps
76
+ pip install pytest-reportlog parameterized
77
+
78
+ - name: Run slow SFT tests on Multi GPU
79
+ if: always()
80
+ run: |
81
+ source activate trl
82
+ make slow_tests
83
+
84
+ - name: Run end-to-end examples tests on multi GPU
85
+ if: always()
86
+ run: |
87
+ source activate trl
88
+ pip install deepspeed
89
+ make test_examples
90
+
91
+ - name: Generate Reports
92
+ if: always()
93
+ run: |
94
+ pip install slack_sdk tabulate
95
+ python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
96
+ python scripts/log_example_reports.py --text_file_name temp_results_sft_tests.txt >> $GITHUB_STEP_SUMMARY
97
+ python scripts/log_example_reports.py --text_file_name temp_results_dpo_tests.txt >> $GITHUB_STEP_SUMMARY
98
+ rm *.txt
.github/workflows/tests.yml ADDED
@@ -0,0 +1,252 @@
1
+ name: Tests
2
+
3
+ on:
4
+ push:
5
+ branches: [ main ]
6
+ pull_request:
7
+ paths:
8
+ # Run only when relevant files are modified
9
+ - ".github/**.yml"
10
+ - "examples/**.py"
11
+ - "scripts/**.py"
12
+ - "tests/**.py"
13
+ - "trl/**.py"
14
+ - "setup.py"
15
+
16
+ env:
17
+ TQDM_DISABLE: 1
18
+ CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}
19
+
20
+ jobs:
21
+ check_code_quality:
22
+ name: Check code quality
23
+ runs-on: ubuntu-latest
24
+ if: github.event.pull_request.draft == false
25
+ steps:
26
+ - uses: actions/checkout@v4
27
+ - name: Set up Python 3.12
28
+ uses: actions/setup-python@v5
29
+ with:
30
+ python-version: 3.12
31
+ - uses: pre-commit/[email protected]
32
+ with:
33
+ extra_args: --all-files
34
+
35
+ tests:
36
+ name: Tests
37
+ strategy:
38
+ matrix:
39
+ python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
40
+ fail-fast: false
41
+ runs-on:
42
+ group: aws-g4dn-2xlarge
43
+ container:
44
+ image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
45
+ options: --gpus all
46
+ defaults:
47
+ run:
48
+ shell: bash
49
+ if: github.event.pull_request.draft == false
50
+ steps:
51
+ - name: Git checkout
52
+ uses: actions/checkout@v4
53
+
54
+ - name: Set up Python ${{ matrix.python-version }}
55
+ uses: actions/setup-python@v5
56
+ with:
57
+ python-version: ${{ matrix.python-version }}
58
+
59
+ - name: Install Make and Git
60
+ run: |
61
+ apt-get update && apt-get install -y make git curl
62
+
63
+ - name: Install uv
64
+ run: |
65
+ curl -LsSf https://astral.sh/uv/install.sh | sh
66
+
67
+ - name: Create Python virtual environment
68
+ run: |
69
+ uv venv
70
+ uv pip install --upgrade setuptools wheel
71
+
72
+ - name: Install dependencies
73
+ run: |
74
+ source .venv/bin/activate
75
+ uv pip install ".[dev]"
76
+
77
+ - name: Test with pytest
78
+ run: |
79
+ source .venv/bin/activate
80
+ make test
81
+
82
+ - name: Post to Slack
83
+ if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
84
+ uses: huggingface/hf-workflows/.github/actions/post-slack@main
85
+ with:
86
+ slack_channel: ${{ env.CI_SLACK_CHANNEL }}
87
+ title: Results with Python ${{ matrix.python-version }} and latest dependencies
88
+ status: ${{ job.status }}
89
+ slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
90
+
91
+ tests_dev:
92
+ name: Tests with dev dependencies
93
+ runs-on:
94
+ group: aws-g4dn-2xlarge
95
+ container:
96
+ image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
97
+ options: --gpus all
98
+ defaults:
99
+ run:
100
+ shell: bash
101
+ if: github.event.pull_request.draft == false
102
+ steps:
103
+ - name: Git checkout
104
+ uses: actions/checkout@v4
105
+
106
+ - name: Set up Python 3.12
107
+ uses: actions/setup-python@v5
108
+ with:
109
+ python-version: '3.12'
110
+
111
+ - name: Install Make and Git
112
+ run: |
113
+ apt-get update && apt-get install -y make git curl
114
+
115
+ - name: Install uv
116
+ run: |
117
+ curl -LsSf https://astral.sh/uv/install.sh | sh
118
+
119
+ - name: Create Python virtual environment
120
+ run: |
121
+ uv venv
122
+ uv pip install --upgrade setuptools wheel
123
+
124
+ - name: Install dependencies
125
+ run: |
126
+ source .venv/bin/activate
127
+ uv pip install ".[dev]"
128
+ uv pip install -U git+https://github.com/huggingface/accelerate.git
129
+ uv pip install -U git+https://github.com/huggingface/datasets.git
130
+ uv pip install -U git+https://github.com/huggingface/transformers.git
131
+
132
+
133
+ - name: Test with pytest
134
+ run: |
135
+ source .venv/bin/activate
136
+ make test
137
+
138
+ - name: Post to Slack
139
+ if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
140
+ uses: huggingface/hf-workflows/.github/actions/post-slack@main
141
+ with:
142
+ slack_channel: ${{ env.CI_SLACK_CHANNEL }}
143
+ title: Results with Python 3.12 and dev dependencies
144
+ status: ${{ job.status }}
145
+ slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
146
+
147
+ tests_wo_optional_deps:
148
+ name: Tests without optional dependencies
149
+ runs-on:
150
+ group: aws-g4dn-2xlarge
151
+ container:
152
+ image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
153
+ options: --gpus all
154
+ defaults:
155
+ run:
156
+ shell: bash
157
+ if: github.event.pull_request.draft == false
158
+ steps:
159
+ - name: Git checkout
160
+ uses: actions/checkout@v4
161
+
162
+ - name: Set up Python 3.12
163
+ uses: actions/setup-python@v5
164
+ with:
165
+ python-version: '3.12'
166
+
167
+ - name: Install Make and Git
168
+ run: |
169
+ apt-get update && apt-get install -y make git curl
170
+
171
+ - name: Install uv
172
+ run: |
173
+ curl -LsSf https://astral.sh/uv/install.sh | sh
174
+
175
+ - name: Create Python virtual environment
176
+ run: |
177
+ uv venv
178
+ uv pip install --upgrade setuptools wheel
179
+
180
+ - name: Install dependencies
181
+ run: |
182
+ source .venv/bin/activate
183
+ uv pip install ".[test]"
184
+
185
+ - name: Test with pytest
186
+ run: |
187
+ source .venv/bin/activate
188
+ make test
189
+
190
+ - name: Post to Slack
191
+ if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
192
+ uses: huggingface/hf-workflows/.github/actions/post-slack@main
193
+ with:
194
+ slack_channel: ${{ env.CI_SLACK_CHANNEL }}
195
+ title: Results with Python 3.12 without optional dependencies
196
+ status: ${{ job.status }}
197
+ slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
198
+
199
+ tests_min_versions:
200
+ name: Tests with minimum versions
201
+ runs-on:
202
+ group: aws-g4dn-2xlarge
203
+ container:
204
+ image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
205
+ options: --gpus all
206
+ defaults:
207
+ run:
208
+ shell: bash
209
+ if: github.event.pull_request.draft == false
210
+ steps:
211
+ - name: Git checkout
212
+ uses: actions/checkout@v4
213
+
214
+ - name: Set up Python 3.12
215
+ uses: actions/setup-python@v5
216
+ with:
217
+ python-version: '3.12'
218
+
219
+ - name: Install Make and Git
220
+ run: |
221
+ apt-get update && apt-get install -y make git curl
222
+
223
+ - name: Install uv
224
+ run: |
225
+ curl -LsSf https://astral.sh/uv/install.sh | sh
226
+
227
+ - name: Create Python virtual environment
228
+ run: |
229
+ uv venv
230
+ uv pip install --upgrade setuptools wheel
231
+
232
+ - name: Install dependencies
233
+ run: |
234
+ source .venv/bin/activate
235
+ uv pip install ".[dev]"
236
+ uv pip install accelerate==1.4.0
237
+ uv pip install datasets==3.0.0
238
+ uv pip install transformers==4.51.0
239
+
240
+ - name: Test with pytest
241
+ run: |
242
+ source .venv/bin/activate
243
+ make test
244
+
245
+ - name: Post to Slack
246
+ if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
247
+ uses: huggingface/hf-workflows/.github/actions/post-slack@main
248
+ with:
249
+ slack_channel: ${{ env.CI_SLACK_CHANNEL }}
250
+ title: Results with Python 3.12 and minimum dependencies versions
251
+ status: ${{ job.status }}
252
+ slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
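Every job in this workflow follows the same recipe: create a `uv` virtual environment, install the package with the relevant extras, and run `make test`. A rough local equivalent of the default test job, assuming `uv` is already installed and you are at the repository root (a sketch, not the exact CI container setup):

```bash
uv venv                                   # create .venv, as the workflow does
source .venv/bin/activate
uv pip install --upgrade setuptools wheel
uv pip install ".[dev]"                   # ".[test]" mirrors the "without optional dependencies" job
make test
```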
.github/workflows/tests_latest.yml ADDED
@@ -0,0 +1,66 @@
1
+ name: Tests latest TRL release with dev dependencies
2
+
3
+ on:
4
+ schedule:
5
+ - cron: '0 0 * * *' # Runs daily at midnight UTC
6
+
7
+ workflow_dispatch:
8
+
9
+ env:
10
+ TQDM_DISABLE: 1
11
+ CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}
12
+
13
+ jobs:
14
+ tests:
15
+ name: Tests latest TRL release with dev dependencies
16
+ runs-on:
17
+ group: aws-g4dn-2xlarge
18
+ container:
19
+ image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
20
+ options: --gpus all
21
+ defaults:
22
+ run:
23
+ shell: bash
24
+ steps:
25
+ - name: Git checkout
26
+ uses: actions/checkout@v4
27
+ with: { ref: v0.18-release }
28
+
29
+ - name: Set up Python 3.12
30
+ uses: actions/setup-python@v5
31
+ with:
32
+ python-version: '3.12'
33
+
34
+ - name: Install Make and Git
35
+ run: |
36
+ apt-get update && apt-get install -y make git curl
37
+
38
+ - name: Install uv
39
+ run: |
40
+ curl -LsSf https://astral.sh/uv/install.sh | sh
41
+
42
+ - name: Create Python virtual environment
43
+ run: |
44
+ uv venv
45
+ uv pip install --upgrade setuptools wheel
46
+
47
+ - name: Install dependencies
48
+ run: |
49
+ source .venv/bin/activate
50
+ uv pip install ".[dev]"
51
+ uv pip install -U git+https://github.com/huggingface/accelerate.git
52
+ uv pip install -U git+https://github.com/huggingface/datasets.git
53
+ uv pip install -U git+https://github.com/huggingface/transformers.git
54
+
55
+ - name: Test with pytest
56
+ run: |
57
+ source .venv/bin/activate
58
+ make test
59
+
60
+ - name: Post to Slack
61
+ uses: huggingface/hf-workflows/.github/actions/post-slack@main
62
+ with:
63
+ slack_channel: ${{ env.CI_SLACK_CHANNEL }}
64
+ title: Results of latest TRL with Python 3.12 and dev dependencies
65
+ status: ${{ job.status }}
66
+ slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
.github/workflows/trufflehog.yml ADDED
@@ -0,0 +1,18 @@
1
+ on:
2
+ push:
3
+
4
+ name: Secret Leaks
5
+
6
+ jobs:
7
+ trufflehog:
8
+ runs-on: ubuntu-latest
9
+ steps:
10
+ - name: Checkout code
11
+ uses: actions/checkout@v4
12
+ with:
13
+ fetch-depth: 0
14
+ - name: Secret Scanning
15
+ uses: trufflesecurity/trufflehog@853e1e8d249fd1e29d0fcc7280d29b03df3d643d
16
+ with:
17
+ # exclude buggy postgres detector that is causing false positives and not relevant to our codebase
18
+ extra_args: --results=verified,unknown --exclude-detectors=postgres
.github/workflows/upload_pr_documentation.yml ADDED
@@ -0,0 +1,16 @@
1
+ name: Upload PR Documentation
2
+
3
+ on:
4
+ workflow_run:
5
+ workflows: ["Build PR Documentation"]
6
+ types:
7
+ - completed
8
+
9
+ jobs:
10
+ build:
11
+ uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
12
+ with:
13
+ package_name: trl
14
+ secrets:
15
+ hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
16
+ comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
.gitignore ADDED
@@ -0,0 +1,144 @@
1
+ *.bak
2
+ .last_checked
3
+ .gitconfig
4
+ *.bak
5
+ *.log
6
+ *~
7
+ ~*
8
+ _tmp*
9
+ tmp*
10
+ tags
11
+
12
+ # Byte-compiled / optimized / DLL files
13
+ __pycache__/
14
+ *.py[cod]
15
+ *$py.class
16
+
17
+ # C extensions
18
+ *.so
19
+
20
+ # Distribution / packaging
21
+ .Python
22
+ env/
23
+ build/
24
+ develop-eggs/
25
+ dist/
26
+ downloads/
27
+ eggs/
28
+ .eggs/
29
+ lib/
30
+ lib64/
31
+ parts/
32
+ sdist/
33
+ var/
34
+ wheels/
35
+ *.egg-info/
36
+ .installed.cfg
37
+ *.egg
38
+
39
+ # PyInstaller
40
+ # Usually these files are written by a python script from a template
41
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
42
+ *.manifest
43
+ *.spec
44
+
45
+ # Installer logs
46
+ pip-log.txt
47
+ pip-delete-this-directory.txt
48
+
49
+ # Unit test / coverage reports
50
+ htmlcov/
51
+ .tox/
52
+ .coverage
53
+ .coverage.*
54
+ .cache
55
+ nosetests.xml
56
+ coverage.xml
57
+ *.cover
58
+ .hypothesis/
59
+
60
+ # Translations
61
+ *.mo
62
+ *.pot
63
+
64
+ # Django stuff:
65
+ *.log
66
+ local_settings.py
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # celery beat schedule file
88
+ celerybeat-schedule
89
+
90
+ # SageMath parsed files
91
+ *.sage.py
92
+
93
+ # dotenv
94
+ .env
95
+
96
+ # virtualenv
97
+ .venv
98
+ venv/
99
+ ENV/
100
+
101
+ # Spyder project settings
102
+ .spyderproject
103
+ .spyproject
104
+
105
+ # Rope project settings
106
+ .ropeproject
107
+
108
+ # mkdocs documentation
109
+ /site
110
+
111
+ # mypy
112
+ .mypy_cache/
113
+
114
+ .vscode
115
+ *.swp
116
+
117
+ # osx generated files
118
+ .DS_Store
119
+ .DS_Store?
120
+ .Trashes
121
+ ehthumbs.db
122
+ Thumbs.db
123
+ .idea
124
+
125
+ # pytest
126
+ .pytest_cache
127
+
128
+ # tools/trust-doc-nbs
129
+ docs_src/.last_checked
130
+
131
+ # symlinks to fastai
132
+ docs_src/fastai
133
+ tools/fastai
134
+
135
+ # link checker
136
+ checklink/cookies.txt
137
+
138
+ # .gitconfig is now autogenerated
139
+ .gitconfig
140
+
141
+ # wandb files
142
+ nbs/wandb/
143
+ examples/notebooks/wandb/
144
+ wandb/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,17 @@
1
+ repos:
2
+ - repo: https://github.com/astral-sh/ruff-pre-commit
3
+ rev: v0.11.10
4
+ hooks:
5
+ - id: ruff-check
6
+ types_or: [ python, pyi ]
7
+ args: [ --fix ]
8
+ - id: ruff-format
9
+ types_or: [ python, pyi ]
10
+
11
+ # - repo: https://github.com/codespell-project/codespell
12
+ # rev: v2.1.0
13
+ # hooks:
14
+ # - id: codespell
15
+ # args:
16
+ # - --ignore-words-list=nd,reacher,thist,ths,magent,ba
17
+ # - --skip=docs/css/termynal.css,docs/js/termynal.js
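The ruff hooks above only run automatically once `pre-commit` is installed into the local clone; a minimal sketch, assuming `pre-commit` is available (it can be installed from PyPI):

```bash
pip install pre-commit        # if not already installed
pre-commit install            # register the git hook for this clone
pre-commit run --all-files    # run the configured ruff check/format hooks on the whole repo
```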
CITATION.cff ADDED
@@ -0,0 +1,34 @@
1
+ cff-version: 1.2.0
2
+ title: 'TRL: Transformer Reinforcement Learning'
3
+ message: >-
4
+ If you use this software, please cite it using the
5
+ metadata from this file.
6
+ type: software
7
+ authors:
8
+ - given-names: Leandro
9
+ family-names: von Werra
10
+ - given-names: Younes
11
+ family-names: Belkada
12
+ - given-names: Lewis
13
+ family-names: Tunstall
14
+ - given-names: Edward
15
+ family-names: Beeching
16
+ - given-names: Tristan
17
+ family-names: Thrush
18
+ - given-names: Nathan
19
+ family-names: Lambert
20
+ - given-names: Shengyi
21
+ family-names: Huang
22
+ - given-names: Kashif
23
+ family-names: Rasul
24
+ - given-names: Quentin
25
+ family-names: Gallouédec
26
+ repository-code: 'https://github.com/huggingface/trl'
27
+ abstract: "With trl you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the transformers library by \U0001F917 Hugging Face. Therefore, pre-trained language models can be directly loaded via transformers. At this point, most decoder and encoder-decoder architectures are supported."
28
+ keywords:
29
+ - rlhf
30
+ - deep-learning
31
+ - pytorch
32
+ - transformers
33
+ license: Apache-2.0
34
+ version: 0.18
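As an optional aside, the citation metadata above can be validated and converted to other citation formats with the third-party `cffconvert` utility (not part of this repository; shown here only as an example):

```bash
pip install cffconvert
cffconvert --validate        # check CITATION.cff against the CFF schema
cffconvert --format bibtex   # print a BibTeX entry built from the metadata above
```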
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,133 @@
1
+
2
+ # Contributor Covenant Code of Conduct
3
+
4
+ ## Our Pledge
5
+
6
+ We as members, contributors, and leaders pledge to make participation in our
7
+ community a harassment-free experience for everyone, regardless of age, body
8
+ size, visible or invisible disability, ethnicity, sex characteristics, gender
9
+ identity and expression, level of experience, education, socio-economic status,
10
+ nationality, personal appearance, race, caste, color, religion, or sexual
11
+ identity and orientation.
12
+
13
+ We pledge to act and interact in ways that contribute to an open, welcoming,
14
+ diverse, inclusive, and healthy community.
15
+
16
+ ## Our Standards
17
+
18
+ Examples of behavior that contributes to a positive environment for our
19
+ community include:
20
+
21
+ * Demonstrating empathy and kindness toward other people
22
+ * Being respectful of differing opinions, viewpoints, and experiences
23
+ * Giving and gracefully accepting constructive feedback
24
+ * Accepting responsibility and apologizing to those affected by our mistakes,
25
+ and learning from the experience
26
+ * Focusing on what is best not just for us as individuals, but for the overall
27
+ community
28
+
29
+ Examples of unacceptable behavior include:
30
+
31
+ * The use of sexualized language or imagery, and sexual attention or advances of
32
+ any kind
33
+ * Trolling, insulting or derogatory comments, and personal or political attacks
34
+ * Public or private harassment
35
+ * Publishing others' private information, such as a physical or email address,
36
+ without their explicit permission
37
+ * Other conduct which could reasonably be considered inappropriate in a
38
+ professional setting
39
+
40
+ ## Enforcement Responsibilities
41
+
42
+ Community leaders are responsible for clarifying and enforcing our standards of
43
+ acceptable behavior and will take appropriate and fair corrective action in
44
+ response to any behavior that they deem inappropriate, threatening, offensive,
45
+ or harmful.
46
+
47
+ Community leaders have the right and responsibility to remove, edit, or reject
48
+ comments, commits, code, wiki edits, issues, and other contributions that are
49
+ not aligned to this Code of Conduct, and will communicate reasons for moderation
50
+ decisions when appropriate.
51
+
52
+ ## Scope
53
+
54
+ This Code of Conduct applies within all community spaces, and also applies when
55
+ an individual is officially representing the community in public spaces.
56
+ Examples of representing our community include using an official e-mail address,
57
+ posting via an official social media account, or acting as an appointed
58
+ representative at an online or offline event.
59
+
60
+ ## Enforcement
61
+
62
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
63
+ reported to the community leaders responsible for enforcement at
64
65
+ All complaints will be reviewed and investigated promptly and fairly.
66
+
67
+ All community leaders are obligated to respect the privacy and security of the
68
+ reporter of any incident.
69
+
70
+ ## Enforcement Guidelines
71
+
72
+ Community leaders will follow these Community Impact Guidelines in determining
73
+ the consequences for any action they deem in violation of this Code of Conduct:
74
+
75
+ ### 1. Correction
76
+
77
+ **Community Impact**: Use of inappropriate language or other behavior deemed
78
+ unprofessional or unwelcome in the community.
79
+
80
+ **Consequence**: A private, written warning from community leaders, providing
81
+ clarity around the nature of the violation and an explanation of why the
82
+ behavior was inappropriate. A public apology may be requested.
83
+
84
+ ### 2. Warning
85
+
86
+ **Community Impact**: A violation through a single incident or series of
87
+ actions.
88
+
89
+ **Consequence**: A warning with consequences for continued behavior. No
90
+ interaction with the people involved, including unsolicited interaction with
91
+ those enforcing the Code of Conduct, for a specified period of time. This
92
+ includes avoiding interactions in community spaces as well as external channels
93
+ like social media. Violating these terms may lead to a temporary or permanent
94
+ ban.
95
+
96
+ ### 3. Temporary Ban
97
+
98
+ **Community Impact**: A serious violation of community standards, including
99
+ sustained inappropriate behavior.
100
+
101
+ **Consequence**: A temporary ban from any sort of interaction or public
102
+ communication with the community for a specified period of time. No public or
103
+ private interaction with the people involved, including unsolicited interaction
104
+ with those enforcing the Code of Conduct, is allowed during this period.
105
+ Violating these terms may lead to a permanent ban.
106
+
107
+ ### 4. Permanent Ban
108
+
109
+ **Community Impact**: Demonstrating a pattern of violation of community
110
+ standards, including sustained inappropriate behavior, harassment of an
111
+ individual, or aggression toward or disparagement of classes of individuals.
112
+
113
+ **Consequence**: A permanent ban from any sort of public interaction within the
114
+ community.
115
+
116
+ ## Attribution
117
+
118
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage],
119
+ version 2.1, available at
120
+ [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
121
+
122
+ Community Impact Guidelines were inspired by
123
+ [Mozilla's code of conduct enforcement ladder][Mozilla CoC].
124
+
125
+ For answers to common questions about this code of conduct, see the FAQ at
126
+ [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
127
+ [https://www.contributor-covenant.org/translations][translations].
128
+
129
+ [homepage]: https://www.contributor-covenant.org
130
+ [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
131
+ [Mozilla CoC]: https://github.com/mozilla/diversity
132
+ [FAQ]: https://www.contributor-covenant.org/faq
133
+ [translations]: https://www.contributor-covenant.org/translations
CONTRIBUTING.md ADDED
@@ -0,0 +1,767 @@
1
+ # How to contribute to TRL?
2
+
3
+ Everyone is welcome to contribute, and we value everybody's contribution. Code
4
+ contributions are not the only way to help the community. Answering questions, helping
5
+ others, and improving the documentation are also immensely valuable.
6
+
7
+ It also helps us if you spread the word! Reference the library in blog posts
8
+ about the awesome projects it made possible, shout out on Twitter every time it has
9
+ helped you, or simply ⭐️ the repository to say thank you.
10
+
11
+ However you choose to contribute, please be mindful and respect our
12
+ [code of conduct](https://github.com/huggingface/trl/blob/main/CODE_OF_CONDUCT.md).
13
+
14
+ **This guide was heavily inspired by the awesome [scikit-learn guide to contributing](https://github.com/scikit-learn/scikit-learn/blob/main/CONTRIBUTING.md).**
15
+
16
+ ## Ways to contribute
17
+
18
+ There are several ways you can contribute to TRL:
19
+
20
+ * Fix outstanding issues with the existing code.
21
+ * Submit issues related to bugs or desired new features.
22
+ * Implement trainers for new post-training algorithms.
23
+ * Contribute to the examples or the documentation.
24
+
25
+ If you don't know where to start, there is a special [Good First
26
+ Issue](https://github.com/huggingface/trl/labels/%F0%9F%91%B6%20good%20first%20issue) listing. It will give you a list of
27
+ open issues that are beginner-friendly and help you start contributing to open-source. The best way to do that is to open a Pull Request and link it to the issue that you'd like to work on. We try to give priority to opened PRs as we can easily track the progress of the fix, and if the contributor does not have time anymore, someone else can take the PR over.
28
+
29
+ For something slightly more challenging, you can also take a look at the [Good Second Issue](https://github.com/huggingface/trl/labels/Good%20Second%20Issue) list. In general though, if you feel like you know what you're doing, go for it and we'll help you get there! 🚀
30
+
31
+ > All contributions are equally valuable to the community. 🥰
32
+
33
+ Before you start contributing make sure you have installed all the dev tools:
34
+
35
+ ```bash
36
+ pip install -e .[dev]
37
+ ```
38
+
39
+ ## Fixing outstanding issues
40
+
41
+ If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](#submitting-a-pull-request-pr) and open a Pull Request!
42
+
43
+ ## Submitting a bug-related issue or feature request
44
+
45
+ Do your best to follow these guidelines when submitting a bug-related issue or a feature request. It will make it easier for us to come back to you quickly and with good feedback.
46
+
47
+ ### Did you find a bug?
48
+
49
+ The TRL library is robust and reliable thanks to users who report the problems they encounter.
50
+
51
+ Before you report an issue, we would really appreciate it if you could **make sure the bug was not
52
+ already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the library itself, and not your code.
53
+
54
+ Once you've confirmed the bug hasn't already been reported, please include the following information in your issue so we can quickly resolve it:
55
+
56
+ * Your **OS type and version**, **Python**, **PyTorch**, **TRL** and **Transformers** versions.
57
+ * A short, self-contained, code snippet that allows us to reproduce the bug in
58
+ less than 30s.
59
+ * The *full* traceback if an exception is raised.
60
+ * Any other information, like screenshots, that you think may help.
61
+
62
+ To get the OS and software versions automatically, run the following command:
63
+
64
+ ```bash
65
+ trl env
66
+ ```
67
+
68
+ ### Do you want a new feature?
69
+
70
+ If there is a new feature you'd like to see in TRL, please open an issue and describe:
71
+
72
+ 1. What is the *motivation* behind this feature? Is it related to a problem or frustration with the library? Is it a feature related to something you need for a project? Is it something you worked on and think it could benefit the community?
73
+
74
+ Whatever it is, we'd love to hear about it!
75
+
76
+ 2. Describe your requested feature in as much detail as possible. The more you can tell us about it, the better we'll be able to help you.
77
+ 3. Provide a *code snippet* that demonstrates the feature's usage.
78
+ 4. If the feature is related to a paper, please include a link.
79
+
80
+ If your issue is well written we're already 80% of the way there by the time you create it.
81
+
82
+ ## Do you want to implement a new trainer?
83
+
84
+ New post-training methods are published frequently and those that satisfy the following criteria are good candidates to be integrated into TRL:
85
+
86
+ * **Simplicity:** Does the new method achieve similar performance as prior methods, but with less complexity? A good example is Direct Preference Optimization (DPO) [[Rafailov et al, 2023]](https://huggingface.co/papers/2305.18290), which provided a simpler and compelling alternative to RLHF methods.
87
+ * **Efficiency:** Does the new method provide a significant improvement in training efficiency? A good example is Odds Ratio Preference Optimization (ORPO) [[Hong et al, 2023]](https://huggingface.co/papers/2403.07691), which utilizes a similar objective as DPO but requires half the GPU VRAM.
88
+
89
+ Methods that only provide incremental improvements at the expense of added complexity or compute costs are unlikely to be included in TRL.
90
+
91
+ If you want to implement a trainer for a new post-training method, first open an issue and provide the following information:
92
+
93
+ * A short description of the method and a link to the paper.
94
+ * Link to the implementation if it is open-sourced.
95
+ * Link to model weights trained with the method if they are available.
96
+
97
+ Based on the community and maintainer feedback, the next step will be to implement the trainer and config classes. See the following examples for inspiration:
98
+
99
+ * Paired preference optimisation: [`dpo_trainer.py`](./trl/trainer/dpo_trainer.py) and [`dpo_config.py`](./trl/trainer/dpo_config.py)
100
+ * RL-based optimisation: [`rloo_trainer.py`](./trl/trainer/rloo_trainer.py) and [`rloo_config.py`](./trl/trainer/rloo_config.py)
101
+ * Online optimisation: [`online_dpo_trainer.py`](./trl/trainer/online_dpo_trainer.py) and [`online_dpo_config.py`](./trl/trainer/online_dpo_config.py)
102
+
103
+ ## Do you want to add documentation?
104
+
105
+ We're always looking for improvements to the documentation that make it more clear and accurate. Please let us know how the documentation can be improved, such as typos, dead links, and any missing, unclear, or inaccurate content... We'll be happy to make the changes or help you contribute if you're interested!
106
+
107
+ ## Submitting a pull request (PR)
108
+
109
+ Before writing code, we strongly advise you to search through the existing PRs or
110
+ issues to make sure that nobody is already working on the same thing. If you are
111
+ unsure, it is always a good idea to open an issue to get some feedback.
112
+
113
+ You will need basic `git` proficiency to be able to contribute to
114
+ TRL. `git` is not the easiest tool to use but it has the greatest
115
+ manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
116
+ Git](https://git-scm.com/book/en/v2) is a very good reference.
117
+
118
+ Follow these steps to start contributing:
119
+
120
+ 1. Fork the [repository](https://github.com/huggingface/trl) by
121
+ clicking on the 'Fork' button on the repository's page. This creates a copy of the code
122
+ under your GitHub user account.
123
+
124
+ 2. Clone your fork to your local disk, and add the base repository as a remote. The following command
125
+ assumes you have your public SSH key uploaded to GitHub. See the following guide for more
126
+ [information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository).
127
+
128
+ ```bash
129
+ $ git clone [email protected]:<your Github handle>/trl.git
130
+ $ cd trl
131
+ $ git remote add upstream https://github.com/huggingface/trl.git
132
+ ```
133
+
134
+ 3. Create a new branch to hold your development changes, and do this for every new PR you work on.
135
+
136
+ Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)):
137
+
138
+ ```bash
139
+ $ git checkout main
140
+ $ git fetch upstream
141
+ $ git merge upstream/main
142
+ ```
143
+
144
+ Once your `main` branch is synchronized, create a new branch from it:
145
+
146
+ ```bash
147
+ $ git checkout -b a-descriptive-name-for-my-changes
148
+ ```
149
+
150
+ **Do not** work on the `main` branch.
151
+
152
+ 4. Set up a development environment by running the following command in a conda or a virtual environment you've created for working on this library:
153
+
154
+ ```bash
155
+ $ pip install -e .[dev]
156
+ ```
157
+
158
+ (If TRL was already installed in the virtual environment, remove
159
+ it with `pip uninstall trl` before reinstalling it.)
160
+
161
+ Alternatively, if you are using [Visual Studio Code](https://code.visualstudio.com/Download), the fastest way to get set up is by using
162
+ the provided Dev Container. Documentation on how to get started with dev containers is available [here](https://code.visualstudio.com/docs/remote/containers).
163
+
164
+ 5. Develop the features on your branch.
165
+
166
+ As you work on the features, you should make sure that the test suite
167
+    passes. You should run the tests impacted by your changes like this:
169
+
170
+ ```bash
171
+ $ pytest tests/<TEST_TO_RUN>.py
172
+ ```
173
+
174
+    > The following commands leverage the `make` utility.
175
+
176
+ You can also run the full suite with the following command.
177
+
178
+ ```bash
179
+ $ make test
180
+ ```
181
+
182
+ TRL relies on `ruff` for maintaining consistent code formatting across its source files. Before submitting any PR, you should apply automatic style corrections and run code verification checks.
183
+
184
+ We provide a `precommit` target in the `Makefile` that simplifies this process by running all required checks and optimizations on only the files modified by your PR.
185
+
186
+ To apply these checks and corrections in one step, use:
187
+
188
+ ```bash
189
+ $ make precommit
190
+ ```
191
+
192
+ This command runs the following:
193
+ - Executes `pre-commit` hooks to automatically fix style issues with `ruff` and other tools.
194
+ - Runs additional scripts such as adding copyright information.
195
+
196
+ If you prefer to apply the style corrections separately or review them individually, the `pre-commit` hook will handle the formatting for the files in question.
197
+
198
+ Once you're happy with your changes, add changed files using `git add` and
199
+ make a commit with `git commit` to record your changes locally:
200
+
201
+ ```bash
202
+ $ git add modified_file.py
203
+ $ git commit
204
+ ```
205
+
206
+ Please write [good commit messages](https://chris.beams.io/posts/git-commit/).
207
+
208
+ It is a good idea to sync your copy of the code with the original
209
+ repository regularly. This way you can quickly account for changes:
210
+
211
+ ```bash
212
+ $ git fetch upstream
213
+ $ git rebase upstream/main
214
+ ```
215
+
216
+ Push the changes to your account using:
217
+
218
+ ```bash
219
+ $ git push -u origin a-descriptive-name-for-my-changes
220
+ ```
221
+
222
+ 6. Once you are satisfied (**and the checklist below is happy too**), go to the
223
+ webpage of your fork on GitHub. Click on 'Pull request' to send your changes
224
+ to the project maintainers for review.
225
+
226
+ 7. It's ok if maintainers ask you for changes. It happens to core contributors too! To ensure everyone can review your changes in the pull request, work on your local branch and push the updates to your fork. They will automatically appear in the pull request.
227
+
228
+
229
+ ### Checklist
230
+
231
+ 1. The title of your pull request should be a summary of its contribution;
232
+ 2. If your pull request addresses an issue, please mention the issue number in
233
+ the pull request description to make sure they are linked (and people
234
+ consulting the issue know you are working on it);
235
+ 3. To indicate a work in progress please prefix the title with `[WIP]`, or mark
236
+ the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate
237
+ it from PRs ready to be merged;
238
+ 4. Make sure existing tests pass;
239
+ 5. Add high-coverage tests. No quality testing = no merge.
240
+
241
+
242
+ ### Tests
243
+
244
+ An extensive test suite is included to test the library behavior and several examples. Library tests can be found in
245
+ the [tests folder](https://github.com/huggingface/trl/tree/main/tests).
246
+
247
+ We use `pytest` to run the tests. From the root of the
248
+ repository here's how to run tests with `pytest` for the library:
249
+
250
+ ```bash
251
+ $ python -m pytest -sv ./tests
252
+ ```
253
+
254
+ That's essentially how `make test` is implemented!
255
+
256
+ You can specify a smaller set of tests to test only the feature
257
+ you're working on.
258
+
259
+ ### Default values guidelines
260
+
261
+ 1. **Use defaults when appropriate**:
262
+
263
+ Provide default values unless the parameter's value varies significantly by use case. For example, datasets or models should not have defaults, but parameters like `learning_rate` should.
264
+
265
+ 2. **Prioritize proven defaults**:
266
+
267
+ Default values should align with those recommended in the original paper or method. Alternatives require strong evidence of superior performance in most cases.
268
+
269
+ 3. **Ensure safety and predictability**:
270
+
271
+    Defaults must be safe, predictable, and reliable. Avoid settings that could lead to surprising outcomes, such as excessive memory usage or poor performance in edge cases.
272
+
273
+ 4. **Balance consistency and flexibility**:
274
+
275
+    Aim for consistent defaults across similar functions or methods. However, consistency should not take precedence over points 2 and 3.
276
+
277
+ 5. **Opt-in for new features**:
278
+
279
+    Do not enable new features or improvements (e.g., novel loss functions) by default. Users should explicitly opt in to use them. A rough illustrative sketch follows this list.
280
+
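+ As an illustration of these guidelines, here is a hypothetical config (not an actual TRL class) where values that vary by use case have no default, well-established knobs keep proven defaults, and a new feature stays opt-in:
+
+ ```python
+ # Hypothetical config illustrating the default-value guidelines above.
+ from dataclasses import dataclass, field
+
+
+ @dataclass
+ class MyTrainingConfig:
+     # Guideline 1: no default for values that vary strongly by use case.
+     model_name_or_path: str = field(metadata={"help": "Model to train; must be set explicitly."})
+
+     # Guideline 2: default aligned with the original paper or proven practice.
+     learning_rate: float = 1e-5
+
+     # Guideline 3: conservative, predictable default (no surprising memory usage).
+     per_device_train_batch_size: int = 8
+
+     # Guideline 5: new features are opt-in, never enabled by default.
+     use_experimental_loss: bool = False
+ ```
+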
281
+ ### Writing documentation
282
+
283
+ High-quality documentation is crucial for maintaining a project that is easy to use, understand, and extend. When adding new features, ensure they are thoroughly documented to maintain consistency and clarity throughout the project.
284
+
285
+ To illustrate what good documentation looks like, here’s an example of a well-documented function:
286
+
287
+ ````python
288
+ def replicate_str(string: str, n: int, sep: str = " ") -> str:
289
+ r"""
290
+ Replicate a string `n` times with a separator.
291
+
292
+ Args:
293
+ string (`str`):
294
+ String to replicate.
295
+ n (`int`):
296
+ Number of times to replicate the string.
297
+ sep (`str`, *optional*, defaults to `" "`):
298
+ Separator to use between each replication.
299
+
300
+ Returns:
301
+ `str`: The replicated string.
302
+
303
+ Examples:
304
+ ```python
305
+ >>> replicate_str("hello", 3)
306
+ "hello hello hello"
307
+ >>> replicate_str("hello", 3, sep=", ")
308
+ "hello, hello, hello"
309
+ ```
310
+ """
311
+ return sep.join([string] * n)
312
+ ````
313
+
314
+ * **Line Wrapping:** Applied a consistent line wrap at column 120 to improve readability.
315
+ * **Definite Articles:** Removed definite articles where possible to streamline language (e.g., changed "The string to replicate" to "String to replicate").
316
+ * **Type Annotations:**
317
+ * Always include type definitions, indicating if a parameter is optional and specifying the default value.
318
+ * Note that `Optional` means that the value can be `None`, and `*optional*` means that it is not required for the user to pass a value.
319
+ E.g., for arguments that can't be `None` and aren't required:
320
+
321
+ ```python
322
+ foo (`int`, *optional*, defaults to `4`):
323
+ ```
324
+
325
+ For arguments that can be `None` and are required:
326
+
327
+ ```python
328
+ foo (`Optional[int]`):
329
+ ```
330
+
331
+ for arguments that can be `None` and aren't required:
332
+
333
+ ```python
334
+ foo (`Optional[int]`, *optional*, defaults to `None`):
335
+ ```
336
+
337
+ * **String Defaults:**
338
+ * Ensured that default string values are wrapped in double quotes:
339
+
340
+ ```python
341
+ defaults to `"foo"`
342
+ ```
343
+
344
+ * **Dictionary Typing:**
345
+ * Replaced generic `dict` type hints with more explicit `dict[str, Any]` to clarify expected key-value pairs.
346
+ * **Default Value Formatting:**
347
+ * Consistently surrounded default values with backticks for improved formatting:
348
+
349
+ ```python
350
+ defaults to `4`
351
+ ```
352
+
353
+ * **Sub-sectioning:** When the number of arguments is large, consider breaking them into sub-sections for better readability.
354
+
355
+ ```python
356
+ def calculate_statistics(data: list[float], precision: int = 2, include_variance: bool = False) -> dict[str, float]:
357
+ r"""
358
+ Calculates basic statistics for a given dataset.
359
+
360
+ Args:
361
+ > Data inputs
362
+
363
+ data (`list[float]`):
364
+ A list of numerical values to analyze.
365
+
366
+ > Configuration parameters
367
+
368
+ precision (`int`, *optional*, defaults to `2`):
369
+ Number of decimal places to round the results.
370
+ include_variance (`bool`, *optional*, defaults to `False`):
371
+ Whether to include the variance of the dataset in the results.
372
+
373
+ Returns:
374
+ `dict[str, float]`:
375
+ A dictionary containing calculated statistics such as mean, median, and optionally variance.
376
+ """
377
+ ...
378
+ ```
379
+
380
+ ### Deprecation and backward compatibility
381
+
382
+ Our approach to deprecation and backward compatibility is flexible and based on the feature’s usage and impact. Each deprecation is carefully evaluated, aiming to balance innovation with user needs.
383
+
384
+ When a feature or component is marked for deprecation, its use will emit a warning message. This warning will include:
385
+
386
+ - **Transition Guidance**: Instructions on how to migrate to the alternative solution or replacement.
387
+ - **Removal Version**: The target version when the feature will be removed, providing users with a clear timeframe to transition.
388
+
389
+ Example:
390
+
391
+ ```python
392
+ warnings.warn(
393
+ "The `Trainer.foo` method is deprecated and will be removed in version 0.14.0. "
394
+ "Please use the `Trainer.bar` class instead.",
395
+ FutureWarning,
396
+ )
397
+ ```
398
+
399
+ The deprecation and removal schedule is based on each feature's usage and impact, with examples at two extremes:
400
+
401
+ - **Experimental or Low-Use Features**: For a feature that is experimental or has limited usage, backward compatibility may not be maintained between releases. Users should therefore anticipate potential breaking changes from one version to the next.
402
+
403
+ - **Widely-Used Components**: For a feature with high usage, we aim for a more gradual transition period of approximately **5 months**, generally scheduling removal around **5 minor releases** after the initial deprecation warning.
404
+
405
+ These examples represent the two ends of a continuum. The specific timeline for each feature will be determined individually, balancing innovation with user stability needs.
406
+
407
+ ### Working with warnings
408
+
409
+ Warnings play a critical role in guiding users toward resolving potential issues, but they should be used thoughtfully to avoid unnecessary noise. Unlike logging, which provides informational context or operational details, warnings signal conditions that require attention and action. Overusing warnings can dilute their importance, leading users to ignore them entirely.
410
+
411
+ #### Definitions
412
+
413
+ - **Correct**: An operation is correct if it is valid, follows the intended approach, and aligns with the current best practices or guidelines within the codebase. This is the recommended or intended way to perform the operation.
414
+ - **Supported**: An operation is supported if it is technically valid and works within the current codebase, but it may not be the most efficient, optimal, or recommended way to perform the task. This includes deprecated features or legacy approaches that still work but may be phased out in the future.
415
+
416
+ #### Choosing the right message
417
+
418
+ - **Correct → No warning**:
419
+ If the operation is fully valid and expected, no message should be issued. The system is working as intended, so no warning is necessary.
420
+
421
+ - **Correct but deserves attention → No warning, possibly a log message**:
422
+ When an operation is correct but uncommon or requires special attention, providing an informational message can be helpful. This keeps users informed without implying any issue. If available, use the logger to output this message. Example:
423
+
424
+ ```python
425
+ logger.info("This is an informational message about a rare but correct operation.")
426
+ ```
427
+
428
+ - **Correct but very likely a mistake → Warning with option to disable**:
429
+ In rare cases, you may want to issue a warning for a correct operation that’s very likely a mistake. In such cases, you must provide an option to suppress the warning. This can be done with a flag in the function. Example:
430
+
431
+ ```python
432
+ def my_function(foo, bar, _warn=True):
433
+ if foo == bar:
434
+ if _warn:
435
+ warnings.warn("foo and bar are the same, this is likely a mistake. Ignore this warning by setting `_warn=False`.")
436
+ # Do something
437
+ ```
438
+
439
+ - **Supported but not correct → Warning**:
440
+ If the operation is technically supported but is deprecated, suboptimal, or could cause future issues (e.g., conflicting arguments), a warning should be raised. This message should be actionable, meaning it must explain how to resolve the issue. Example:
441
+
442
+ ```python
443
+ def my_function(foo, bar):
444
+ if foo and bar:
445
+ warnings.warn("Both `foo` and `bar` were provided, but only one is allowed. Ignoring `foo`. Please pass only one of these arguments.")
446
+ # Do something
447
+ ```
448
+
449
+ - **Not supported → Exception**:
450
+ If the operation is invalid or unsupported, raise an exception. This indicates that the operation cannot be performed and requires immediate attention. Example:
451
+
452
+ ```python
453
+ def my_function(foo, bar):
454
+ if foo and bar:
455
+ raise ValueError("Both `foo` and `bar` were provided, but only one is allowed. Please pass only one of these arguments.")
456
+ ```
457
+
458
+ By following this classification, you ensure that warnings, information, and exceptions are used appropriately, providing clear guidance to the user without cluttering the system with unnecessary messages.
459
+
460
+
461
+ ## Making a release
462
+
463
+ > [!NOTE]
464
+ > VERSION needs to be formatted following the `v{major}.{minor}.{patch}` convention. We need to follow this convention to be able to retrieve versioned scripts.
465
+
466
+ #### 0. Prerequisites
467
+
468
+ - Dependencies:
469
+ - twine: `pip install build twine`
470
+ - Create an account on each of the following (and join the `trl` project):
471
+ - PyPI: https://pypi.org/
472
+ - Test PyPI: https://test.pypi.org/
473
+
474
+ ### Major/Minor Release
475
+
476
+ #### 1. Ensure your local repository is up to date with the upstream repository
477
+
478
+ ```bash
479
+ git checkout main
480
+ git pull origin main
481
+ ```
482
+
483
+ > [!WARNING]
484
+ > Do not merge other pull requests into `main` until the release is done. This is to ensure that the release is stable and does not include any untested changes. Announce internally (#trl-internal) to other maintainers that you are doing a release and that they must not merge PRs until the release is done.
485
+
486
+ #### 2. Create a release branch from main
487
+
488
+ ```bash
489
+ git checkout -b release-v{major}.{minor}
490
+ ```
491
+
492
+ #### 3. Change the version in the following files
493
+
494
+ - `.github/workflows/tests_latest.yml`:
495
+ ```diff
496
+ - with: { ref: v{major}.{minor-1}-release }
497
+ + with: { ref: v{major}.{minor}-release }
498
+ ```
499
+ - `CITATION.cff`
500
+ ```diff
501
+ - version: {major}.{minor-1}
502
+ + version: {major}.{minor}
503
+ ```
504
+ - `trl/__init__.py`
505
+ ```diff
506
+ - __version__ = "{major}.{minor}.0.dev0"
507
+ + __version__ = "{major}.{minor}.0"
508
+ ```
509
+ - `setup.cfg`
510
+ ```diff
511
+ - version = {major}.{minor}.0.dev0
512
+ + version = {major}.{minor}.0
513
+ ```
514
+
515
+ #### 4. Commit and push these changes
516
+
517
+ ```shell
518
+ git add .github/workflows/tests_latest.yml CITATION.cff trl/__init__.py setup.cfg
519
+ git commit -m 'Release: {major}.{minor}'
520
+ git push origin release-v{major}.{minor}
521
+ ```
522
+
523
+ #### 5. Create a pull request
524
+
525
+ from `release-v{major}.{minor}` to `main`, named `Release: v{major}.{minor}`, wait for tests to pass, and request a review.
526
+
527
+ #### 6. Once the pull request is approved, merge it into `main`
528
+
529
+ #### 7. Add a tag in git to mark the release
530
+
531
+ ```shell
532
+ git checkout main
533
+ git pull origin main
534
+ git tag -a v{major}.{minor}.0 -m 'Adds tag v{major}.{minor}.0 for PyPI'
535
+ git push origin v{major}.{minor}.0
536
+ ```
537
+
538
+ #### 8. Create a branch `v{major}.{minor}-release` for future patch releases.
539
+
540
+ ```shell
541
+ git checkout -b v{major}.{minor}-release
542
+ git push origin v{major}.{minor}-release
543
+ ```
544
+
545
+ This ensures that future patch releases (`v{major}.{minor}.1`, `v{major}.{minor}.2`, etc.) can be made separately from `main`.
546
+
547
+ #### 9. Create the wheels for your release
548
+
549
+ These are the artifacts that will be uploaded to PyPI and installed by users via `pip install trl`.
550
+
551
+ Clean previous builds:
552
+
553
+ ```shell
554
+ rm -rf build dist
555
+ ```
556
+
557
+ At the root of your repo, run
558
+
559
+ ```bash
560
+ python -m build .
561
+ ```
562
+
563
+ This will create a folder named `dist` containing the new version of your package.
564
+
565
+ #### 10. Upload the package to PyPI Test
566
+
567
+ > [!IMPORTANT]
568
+ > Do not skip this step. It is important to test the package before uploading it to the main PyPI server.
569
+
570
+ ```shell
571
+ twine upload dist/* -r testpypi
572
+ ```
573
+
574
+ Then in a fresh environment containing all dependencies you need, try to install your new package from the PyPI test server.
575
+
576
+ ```bash
577
+ pip install -i https://test.pypi.org/simple/ trl
578
+ ```
579
+
580
+ You might get errors for missing dependencies, since the Test PyPI server does not host all the packages that PyPI does. To make sure you have everything installed, you can run:
581
+
582
+ ```bash
583
+ pip install trl
584
+ pip uninstall trl
585
+ ```
586
+
587
+ (the second line will remove trl but keep all its dependencies).
588
+
589
+ Also make sure you can actually use the package! Run the following line:
590
+
591
+ ```bash
592
+ python -c "from trl import *"
593
+ ```
594
+
595
+ along with anything that tests the following (a minimal smoke check is sketched after this list):
596
+
597
+ - the core feature of your package
598
+ - the new features you’re adding in the release
599
+
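+ For example, a minimal smoke check might look like the following (extend it to exercise whatever the release actually changed; the imports shown are just common entry points):
+
+ ```python
+ # Minimal smoke check after installing the release candidate from Test PyPI.
+ import trl
+ from trl import DPOTrainer, SFTTrainer  # core entry points should import cleanly
+
+ print(trl.__version__)  # should match the version you just uploaded
+ ```
+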
600
+ #### 11. Publish on PyPI
601
+
602
+ > [!WARNING]
603
+ > This can't be reverted. Make sure you have tested everything before doing this step.
604
+
605
+ ```shell
606
+ twine upload dist/*
607
+ ```
608
+
609
+ #### 12. Create a GitHub Release
610
+
611
+ 1. Go to the repo’s [releases section](https://github.com/huggingface/trl/releases) on GitHub.
612
+ 2. Click **Draft a new release**.
613
+ 3. Select the `v{major}.{minor}.0` tag you just created in step 7.
614
+ 4. Add a title (`v{major}.{minor}.0`) and a short description of what’s new.
615
+ 5. Click **Publish Release**.
616
+
617
+ #### 13. Bump to dev version
618
+
619
+ 1. Create a branch `bump-dev-version-{major}.{minor+1}` from `main` and check it out.
620
+
621
+ ```shell
622
+ git checkout -b bump-dev-version-{major}.{minor+1}
623
+ ```
624
+
625
+ 2. Change the version in the following files:
626
+ 1. `trl/__init__.py`
627
+ ```diff
628
+ - __version__ = "{major}.{minor}.0"
629
+ + __version__ = "{major}.{minor+1}.0.dev0"
630
+ ```
631
+ 2. `setup.cfg`
632
+ ```diff
633
+ - version = {major}.{minor}.0
634
+ + version = {major}.{minor+1}.0.dev0
635
+ ```
636
+
637
+ 3. Commit and push these changes
638
+
639
+ ```shell
640
+ git add trl/__init__.py setup.cfg
641
+ git commit -m '⬆️ Bump dev version'
642
+ git push origin bump-dev-version-{major}.{minor+1}
643
+ ```
644
+
645
+ 4. Create a pull request from `bump-dev-version-{major}.{minor+1}` to `main`, named `⬆️ Bump dev version`, and request urgent review.
646
+
647
+ 5. Once the pull request is approved, merge it into `main`.
648
+
649
+ 6. The codebase is now ready for the next development cycle; inform the team in the #trl-internal channel.
650
+
651
+
652
+ ## Making a patch release
653
+
654
+ #### 1. Ensure your local repository is up to date with the upstream repository
655
+
656
+ ```bash
657
+ git checkout v{major}.{minor}-release
658
+ git pull origin v{major}.{minor}-release
659
+ ```
660
+
661
+ #### 2. Cherry-pick the changes you want to include in the patch release
662
+
663
+ ```bash
664
+ git cherry-pick <commit-hash-0>
665
+ git cherry-pick <commit-hash-1>
666
+ ...
667
+ ```
668
+
669
+ #### 3. Change the version in the following files
670
+
671
+ - `trl/__init__.py`
672
+ ```diff
673
+ - __version__ = "{major}.{minor}.{patch-1}"
674
+ + __version__ = "{major}.{minor}.{patch}"
675
+ ```
676
+ - `setup.cfg`
677
+ ```diff
678
+ - version = {major}.{minor}.{patch-1}
679
+ + version = {major}.{minor}.{patch}
680
+ ```
681
+
682
+ #### 4. Commit and push these changes
683
+
684
+ ```shell
685
+ git add trl/__init__.py setup.cfg
686
+ git commit -m 'Release: {major}.{minor}.{patch}'
687
+ git push origin v{major}.{minor}-release
688
+ ```
689
+
690
+ #### 5. Wait for the CI to pass
691
+
692
+ #### 6. Add a tag in git to mark the release
693
+
694
+ ```shell
695
+ git tag -a v{major}.{minor}.{patch} -m 'Adds tag v{major}.{minor}.{patch} for PyPI'
696
+ git push origin v{major}.{minor}.{patch}
697
+ ```
698
+
699
+ #### 7. Create the wheels for your release
700
+
701
+ These are the artifacts that will be uploaded to PyPI and installed by users via `pip install trl`.
702
+
703
+ Clean previous builds:
704
+
705
+ ```shell
706
+ rm -rf build dist
707
+ ```
708
+
709
+ At the root of your repo, run
710
+
711
+ ```bash
712
+ python -m build .
713
+ ```
714
+
715
+ This will create a folder named `dist` containing the new version of your package.
716
+
717
+ #### 8. Upload the package to PyPI Test
718
+
719
+ > [!IMPORTANT]
720
+ > Do not skip this step. It is important to test the package before uploading it to the main PyPI server.
721
+
722
+ ```shell
723
+ twine upload dist/* -r testpypi
724
+ ```
725
+
726
+ Then in a fresh environment containing all dependencies you need, try to install your new package from the PyPI test server.
727
+
728
+ ```bash
729
+ pip install -i https://test.pypi.org/simple/ trl
730
+ ```
731
+
732
+ You might get errors for missing dependencies, since the Test PyPI server does not host all the packages that PyPI does. To make sure you have everything installed, you can run:
733
+
734
+ ```bash
735
+ pip install trl
736
+ pip uninstall trl
737
+ ```
738
+
739
+ (the second line will remove trl but keep all its dependencies).
740
+
741
+ Also make sure you can actually use the package! Run the following line:
742
+
743
+ ```bash
744
+ python -c "from trl import *"
745
+ ```
746
+
747
+ along with anything that tests:
748
+
749
+ - the core feature of your package
750
+ - the new features you’re adding in the release
751
+
752
+ #### 9. Publish on PyPI
753
+
754
+ > [!WARNING]
755
+ > This can't be reverted. Make sure you have tested everything before doing this step.
756
+
757
+ ```shell
758
+ twine upload dist/*
759
+ ```
760
+
761
+ #### 10. Create a GitHub Release
762
+
763
+ 1. Go to the repo’s [releases section](https://github.com/huggingface/trl/releases) on GitHub.
764
+ 2. Click **Draft a new release**.
765
+ 3. Select the `v{major}.{minor}.{patch}` tag you just created in step 6.
766
+ 4. Add a title (`v{major}.{minor}.{patch}`) and a short description of what’s new.
767
+ 5. Click **Publish Release**.
Dockerfile ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # https://huggingface.co/docs/hub/spaces-dev-mode#docker-spaces
2
+
3
+ FROM python:3.13-bookworm
4
+
5
+ RUN apt-get update
6
+ RUN apt-get install -y \
7
+ bash \
8
+ curl \
9
+ git \
10
+ git-lfs \
11
+ htop \
12
+ procps \
13
+ nano \
14
+ vim \
15
+ wget
16
+ RUN rm -fr /var/lib/apt/lists/*
17
+
18
+ RUN useradd -m -u 1000 user
19
+
20
+ WORKDIR /app
21
+ RUN chown user /app
22
+ RUN chmod 755 /app
23
+
24
+ USER user
25
+ ENV PATH="/home/user/.local/bin:$PATH"
26
+ RUN curl -fsSL https://pyenv.run | bash
27
+ RUN curl -LsSf https://astral.sh/uv/install.sh | sh
28
+
29
+ COPY --chown=user . /app
30
+
31
+ RUN ls -la /app
32
+
33
+ RUN uv sync
34
+
35
+ # `7860` is the default port for Hugging Face Spaces running on Docker
36
+ # https://huggingface.co/docs/hub/en/spaces-config-reference
37
+ CMD ["python", "-m", "http.server", "--directory", "public", "7860"]
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2020-2025 The HuggingFace Team
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
MANIFEST.in ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ include LICENSE
2
+ include CONTRIBUTING.md
3
+ include README.md
4
+ recursive-exclude * __pycache__
5
+ include trl/templates/*.md
6
+ include trl/accelerate_configs/*.yaml
Makefile ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .PHONY: test precommit common_tests slow_tests test_examples tests_gpu
2
+
3
+ check_dirs := examples tests trl
4
+
5
+ ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs
6
+ COMMAND_FILES_PATH = `pwd`/commands
7
+
8
+ test:
9
+ 	pytest -n auto -m "not slow and not low-priority" -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504|not less than or equal to 0.01)' tests/
10
+
11
+ precommit:
12
+ python scripts/add_copyrights.py
13
+ pre-commit run --all-files
14
+
15
+ slow_tests:
16
+ pytest -m "slow" tests/ $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",)
17
+
18
+ test_examples:
19
+ touch temp_results_sft_tests.txt
20
+ for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
21
+ TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_sft.sh; \
22
+ echo $$?','$${file} >> temp_results_sft_tests.txt; \
23
+ done
24
+
25
+ touch temp_results_dpo_tests.txt
26
+ for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
27
+ TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_dpo.sh; \
28
+ echo $$?','$${file} >> temp_results_dpo_tests.txt; \
29
+ done
README.md ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Trl
3
+ emoji: 🚀
4
+ colorFrom: yellow
5
+ colorTo: green
6
+ sdk: docker
7
+ pinned: false
8
+ ---
9
+
10
+ # TRL - Transformer Reinforcement Learning
11
+
12
+ <div style="text-align: center">
13
+ <img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png" alt="TRL Banner">
14
+ </div>
15
+
16
+ <hr> <br>
17
+
18
+ <h3 align="center">
19
+ <p>A comprehensive library to post-train foundation models</p>
20
+ </h3>
21
+
22
+ <p align="center">
23
+ <a href="https://github.com/huggingface/trl/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/huggingface/trl.svg?color=blue"></a>
24
+ <a href="https://huggingface.co/docs/trl/index"><img alt="Documentation" src="https://img.shields.io/website?label=documentation&url=https%3A%2F%2Fhuggingface.co%2Fdocs%2Ftrl%2Findex&down_color=red&down_message=offline&up_color=blue&up_message=online"></a>
25
+ <a href="https://github.com/huggingface/trl/releases"><img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/trl.svg"></a>
26
+ <a href="https://huggingface.co/trl-lib"><img alt="Hugging Face Hub" src="https://img.shields.io/badge/🤗%20Hub-trl--lib-yellow"></a>
27
+ </p>
28
+
29
+ ## Overview
30
+
31
+ TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO). Built on top of the [🤗 Transformers](https://github.com/huggingface/transformers) ecosystem, TRL supports a variety of model architectures and modalities, and can be scaled up across various hardware setups.
32
+
33
+ ## Highlights
34
+
35
+ - **Trainers**: Various fine-tuning methods are easily accessible via trainers like [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`GRPOTrainer`](https://huggingface.co/docs/trl/grpo_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer) and more.
36
+
37
+ - **Efficient and scalable**:
38
+ - Leverages [🤗 Accelerate](https://github.com/huggingface/accelerate) to scale from single GPU to multi-node clusters using methods like [DDP](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) and [DeepSpeed](https://github.com/deepspeedai/DeepSpeed).
39
+ - Full integration with [🤗 PEFT](https://github.com/huggingface/peft) enables training on large models with modest hardware via quantization and LoRA/QLoRA.
40
+ - Integrates [🦥 Unsloth](https://github.com/unslothai/unsloth) for accelerating training using optimized kernels.
41
+
42
+ - **Command Line Interface (CLI)**: A simple interface lets you fine-tune with models without needing to write code.
43
+
44
+ ## Installation
45
+
46
+ ### Python Package
47
+
48
+ Install the library using `pip`:
49
+
50
+ ```bash
51
+ pip install trl
52
+ ```
53
+
54
+ ### From source
55
+
56
+ If you want to use the latest features before an official release, you can install TRL from source:
57
+
58
+ ```bash
59
+ pip install git+https://github.com/huggingface/trl.git
60
+ ```
61
+
62
+ ### Repository
63
+
64
+ If you want to use the examples you can clone the repository with the following command:
65
+
66
+ ```bash
67
+ git clone https://github.com/huggingface/trl.git
68
+ ```
69
+
70
+ ## Quick Start
71
+
72
+
73
+ For more flexibility and control over training, TRL provides dedicated trainer classes to post-train language models or PEFT adapters on a custom dataset. Each trainer in TRL is a light wrapper around the 🤗 Transformers trainer and natively supports distributed training methods like DDP, DeepSpeed ZeRO, and FSDP.
74
+
75
+ ### `SFTTrainer`
76
+
77
+ Here is a basic example of how to use the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer):
78
+
79
+ ```python
80
+ from trl import SFTTrainer
81
+ from datasets import load_dataset
82
+
83
+ dataset = load_dataset("trl-lib/Capybara", split="train")
84
+
85
+ trainer = SFTTrainer(
86
+ model="Qwen/Qwen2.5-0.5B",
87
+ train_dataset=dataset,
88
+ )
89
+ trainer.train()
90
+ ```
91
+
92
+ ### `GRPOTrainer`
93
+
94
+ [`GRPOTrainer`](https://huggingface.co/docs/trl/grpo_trainer) implements the [Group Relative Policy Optimization (GRPO) algorithm](https://huggingface.co/papers/2402.03300) that is more memory-efficient than PPO and was used to train [Deepseek AI's R1](https://huggingface.co/deepseek-ai/DeepSeek-R1).
95
+
96
+ ```python
97
+ from datasets import load_dataset
98
+ from trl import GRPOTrainer
99
+
100
+ dataset = load_dataset("trl-lib/tldr", split="train")
101
+
102
+ # Dummy reward function: count the number of unique characters in the completions
103
+ def reward_num_unique_chars(completions, **kwargs):
104
+ return [len(set(c)) for c in completions]
105
+
106
+ trainer = GRPOTrainer(
107
+ model="Qwen/Qwen2-0.5B-Instruct",
108
+ reward_funcs=reward_num_unique_chars,
109
+ train_dataset=dataset,
110
+ )
111
+ trainer.train()
112
+ ```
113
+
114
+ ### `DPOTrainer`
115
+
116
+ [`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer) implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train [Llama 3](https://huggingface.co/papers/2407.21783) and many other models. Here is a basic example of how to use the `DPOTrainer`:
117
+
118
+ ```python
119
+ from datasets import load_dataset
120
+ from transformers import AutoModelForCausalLM, AutoTokenizer
121
+ from trl import DPOConfig, DPOTrainer
122
+
123
+ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
124
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
125
+ dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
126
+ training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
127
+ trainer = DPOTrainer(
128
+ model=model,
129
+ args=training_args,
130
+ train_dataset=dataset,
131
+ processing_class=tokenizer
132
+ )
133
+ trainer.train()
134
+ ```
135
+
136
+ ### `RewardTrainer`
137
+
138
+ Here is a basic example of how to use the [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer):
139
+
140
+ ```python
141
+ from trl import RewardConfig, RewardTrainer
142
+ from datasets import load_dataset
143
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
144
+
145
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
146
+ model = AutoModelForSequenceClassification.from_pretrained(
147
+ "Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
148
+ )
149
+ model.config.pad_token_id = tokenizer.pad_token_id
150
+
151
+ dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
152
+
153
+ training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
154
+ trainer = RewardTrainer(
155
+ args=training_args,
156
+ model=model,
157
+ processing_class=tokenizer,
158
+ train_dataset=dataset,
159
+ )
160
+ trainer.train()
161
+ ```
162
+
163
+ ## Command Line Interface (CLI)
164
+
165
+ You can use the TRL Command Line Interface (CLI) to quickly get started with post-training methods like Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO):
166
+
167
+ **SFT:**
168
+
169
+ ```bash
170
+ trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
171
+ --dataset_name trl-lib/Capybara \
172
+ --output_dir Qwen2.5-0.5B-SFT
173
+ ```
174
+
175
+ **DPO:**
176
+
177
+ ```bash
178
+ trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
179
+ --dataset_name argilla/Capybara-Preferences \
180
+ --output_dir Qwen2.5-0.5B-DPO
181
+ ```
182
+
183
+ Read more about CLI in the [relevant documentation section](https://huggingface.co/docs/trl/main/en/clis) or use `--help` for more details.
184
+
185
+ ## Development
186
+
187
+ If you want to contribute to `trl` or customize it to your needs, make sure to read the [contribution guide](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) and to make a dev install:
188
+
189
+ ```bash
190
+ git clone https://github.com/huggingface/trl.git
191
+ cd trl/
192
+ pip install -e .[dev]
193
+ ```
194
+
195
+ ## Citation
196
+
197
+ ```bibtex
198
+ @misc{vonwerra2022trl,
199
+ author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
200
+ title = {TRL: Transformer Reinforcement Learning},
201
+ year = {2020},
202
+ publisher = {GitHub},
203
+ journal = {GitHub repository},
204
+ howpublished = {\url{https://github.com/huggingface/trl}}
205
+ }
206
+ ```
207
+
208
+ ## License
209
+
210
+ This repository's source code is available under the [Apache-2.0 License](LICENSE).
commands/run_dpo.sh ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # This script runs a DPO example end-to-end on a tiny model using different possible configurations
3
+ # but defaults to QLoRA + PEFT
4
+ OUTPUT_DIR="test_dpo/"
5
+ MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
6
+ DATASET_NAME="trl-internal-testing/hh-rlhf-helpful-base-trl-style"
7
+ MAX_STEPS=5
8
+ BATCH_SIZE=2
9
+ SEQ_LEN=128
10
+
11
+ # Handle extra arguments in case one passes accelerate configs.
12
+ EXTRA_ACCELERATE_ARGS=""
13
+ EXTRA_TRAINING_ARGS="""--use_peft \
14
+ --load_in_4bit
15
+ """
16
+
17
+ # This is a hack to get the number of available GPUs
18
+ NUM_GPUS=2
19
+
20
+ if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
21
+ EXTRA_ACCELERATE_ARGS=""
22
+ else
23
+ EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
24
+ # For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
25
+ # on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training.
26
+ if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
27
+ EXTRA_TRAINING_ARGS="--fp16"
28
+ else
29
+ echo "Keeping QLoRA + PEFT"
30
+ fi
31
+ fi
32
+
33
+
34
+ CMD="""
35
+ accelerate launch $EXTRA_ACCELERATE_ARGS \
36
+ --num_processes $NUM_GPUS \
37
+ --mixed_precision 'fp16' \
38
+ `pwd`/trl/scripts/dpo.py \
39
+ --model_name_or_path $MODEL_NAME \
40
+ --dataset_name $DATASET_NAME \
41
+ --output_dir $OUTPUT_DIR \
42
+ --max_steps $MAX_STEPS \
43
+ --per_device_train_batch_size $BATCH_SIZE \
44
+ --max_length $SEQ_LEN \
45
+ $EXTRA_TRAINING_ARGS
46
+ """
47
+
48
+ echo "Starting program..."
49
+
50
+ { # try
51
+ echo $CMD
52
+ eval "$CMD"
53
+ } || { # catch
54
+ # save log for exception
55
+ echo "Operation Failed!"
56
+ exit 1
57
+ }
58
+ exit 0
commands/run_sft.sh ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ # This script runs an SFT example end-to-end on a tiny model using different possible configurations
3
+ # but defaults to QLoRA + PEFT
4
+ OUTPUT_DIR="test_sft/"
5
+ MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
6
+ DATASET_NAME="stanfordnlp/imdb"
7
+ MAX_STEPS=5
8
+ BATCH_SIZE=2
9
+ SEQ_LEN=128
10
+
11
+
12
+ # Handle extra arguments in case one passes accelerate configs.
13
+ EXTRA_ACCELERATE_ARGS=""
14
+ EXTRA_TRAINING_ARGS="""--use_peft \
15
+ --load_in_4bit
16
+ """
17
+
18
+ # Set your number of GPUs here
19
+ NUM_GPUS=2
20
+
21
+ if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
22
+ EXTRA_ACCELERATE_ARGS=""
23
+ else
24
+ EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
25
+ # For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
26
+ # on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training.
27
+ if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
28
+ EXTRA_TRAINING_ARGS="--fp16"
29
+ else
30
+ echo "Keeping QLoRA + PEFT"
31
+ fi
32
+ fi
33
+
34
+
35
+ CMD="""
36
+ accelerate launch $EXTRA_ACCELERATE_ARGS \
37
+ --num_processes $NUM_GPUS \
38
+ --mixed_precision 'fp16' \
39
+ `pwd`/trl/scripts/sft.py \
40
+     --model_name_or_path $MODEL_NAME \
41
+ --dataset_name $DATASET_NAME \
42
+ --output_dir $OUTPUT_DIR \
43
+ --max_steps $MAX_STEPS \
44
+ --per_device_train_batch_size $BATCH_SIZE \
45
+ --max_length $SEQ_LEN \
46
+ $EXTRA_TRAINING_ARGS
47
+ """
48
+
49
+ echo "Starting program..."
50
+
51
+ { # try
52
+ echo $CMD
53
+ eval "$CMD"
54
+ } || { # catch
55
+ # save log for exception
56
+ echo "Operation Failed!"
57
+ exit 1
58
+ }
59
+ exit 0
docker-compose.yml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ services:
2
+ workspace:
3
+ build:
4
+ context: .
5
+ dockerfile: Dockerfile
docker/trl-latest-gpu/Dockerfile ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Builds GPU docker image of PyTorch
2
+ # Uses multi-staged approach to reduce size
3
+ # Stage 1
4
+ # Use base conda image to reduce time
5
+ FROM continuumio/miniconda3:latest AS compile-image
6
+ # Specify py version
7
+ ENV PYTHON_VERSION=3.10
8
+ # Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
9
+ RUN apt-get update && \
10
+ apt-get install -y curl git wget software-properties-common git-lfs && \
11
+ apt-get clean && \
12
+ rm -rf /var/lib/apt/lists*
13
+
14
+ # Install audio-related libraries
15
+ RUN apt-get update && \
16
+ apt install -y ffmpeg
17
+
18
+ RUN apt install -y libsndfile1-dev
19
+ RUN git lfs install
20
+
21
+ # Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
22
+ RUN conda create --name trl python=${PYTHON_VERSION} ipython jupyter pip
23
+ RUN python3 -m pip install --no-cache-dir --upgrade pip
24
+
25
+ # Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
26
+ # We don't install pytorch here yet since CUDA isn't available
27
+ # instead we use the direct torch wheel
28
+ ENV PATH /opt/conda/envs/trl/bin:$PATH
29
+ # Activate our bash shell
30
+ RUN chsh -s /bin/bash
31
+ SHELL ["/bin/bash", "-c"]
32
+
33
+ # Stage 2
34
+ FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS build-image
35
+ COPY --from=compile-image /opt/conda /opt/conda
36
+ ENV PATH /opt/conda/bin:$PATH
37
+
38
+ RUN chsh -s /bin/bash
39
+ SHELL ["/bin/bash", "-c"]
40
+ RUN source activate trl && \
41
+ python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq
42
+
43
+ # Install apt libs
44
+ RUN apt-get update && \
45
+ apt-get install -y curl git wget && \
46
+ apt-get clean && \
47
+ rm -rf /var/lib/apt/lists*
48
+
49
+ # Activate the conda env and install transformers + accelerate from source
50
+ RUN source activate trl && \
51
+ python3 -m pip install -U --no-cache-dir \
52
+ librosa \
53
+ "soundfile>=0.12.1" \
54
+ scipy \
55
+ transformers \
56
+ accelerate \
57
+ peft \
58
+ trl[test]@git+https://github.com/huggingface/trl
59
+
60
+ RUN source activate trl && \
61
+ pip freeze | grep trl
62
+
63
+ RUN echo "source activate trl" >> ~/.profile
64
+
65
+ # Activate the virtualenv
66
+ CMD ["/bin/bash"]
docker/trl-source-gpu/Dockerfile ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Builds GPU docker image of PyTorch
2
+ # Uses multi-staged approach to reduce size
3
+ # Stage 1
4
+ # Use base conda image to reduce time
5
+ FROM continuumio/miniconda3:latest AS compile-image
6
+ # Specify py version
7
+ ENV PYTHON_VERSION=3.10
8
+ # Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
9
+ RUN apt-get update && \
10
+ apt-get install -y curl git wget software-properties-common git-lfs && \
11
+ apt-get clean && \
12
+ rm -rf /var/lib/apt/lists*
13
+
14
+ # Install audio-related libraries
15
+ RUN apt-get update && \
16
+ apt install -y ffmpeg
17
+
18
+ RUN apt install -y libsndfile1-dev
19
+ RUN git lfs install
20
+
21
+ # Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
22
+ RUN conda create --name trl python=${PYTHON_VERSION} ipython jupyter pip
23
+ RUN python3 -m pip install --no-cache-dir --upgrade pip
24
+
25
+ # Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
26
+ # We don't install pytorch here yet since CUDA isn't available
27
+ # instead we use the direct torch wheel
28
+ ENV PATH /opt/conda/envs/trl/bin:$PATH
29
+ # Activate our bash shell
30
+ RUN chsh -s /bin/bash
31
+ SHELL ["/bin/bash", "-c"]
32
+
33
+ # Stage 2
34
+ FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS build-image
35
+ COPY --from=compile-image /opt/conda /opt/conda
36
+ ENV PATH /opt/conda/bin:$PATH
37
+
38
+ RUN chsh -s /bin/bash
39
+ SHELL ["/bin/bash", "-c"]
40
+ RUN source activate trl && \
41
+ python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq
42
+
43
+ # Install apt libs
44
+ RUN apt-get update && \
45
+ apt-get install -y curl git wget && \
46
+ apt-get clean && \
47
+ rm -rf /var/lib/apt/lists*
48
+
49
+ # Activate the conda env and install transformers + accelerate from source
50
+ RUN source activate trl && \
51
+ python3 -m pip install -U --no-cache-dir \
52
+ librosa \
53
+ "soundfile>=0.12.1" \
54
+ scipy \
55
+ git+https://github.com/huggingface/transformers \
56
+ git+https://github.com/huggingface/accelerate \
57
+ git+https://github.com/huggingface/peft \
58
+ trl[test]@git+https://github.com/huggingface/trl
59
+
60
+ RUN source activate trl && \
61
+ pip freeze | grep transformers
62
+
63
+ RUN echo "source activate trl" >> ~/.profile
64
+
65
+ # Activate the virtualenv
66
+ CMD ["/bin/bash"]
docs/source/_toctree.yml ADDED
@@ -0,0 +1,116 @@
1
+ - sections:
2
+ - local: index
3
+ title: TRL
4
+ - local: installation
5
+ title: Installation
6
+ - local: quickstart
7
+ title: Quickstart
8
+ title: Getting started
9
+ - sections:
10
+ - local: dataset_formats
11
+ title: Dataset Formats
12
+ - local: how_to_train
13
+ title: Training FAQ
14
+ - local: logging
15
+ title: Understanding Logs
16
+ title: Conceptual Guides
17
+ - sections:
18
+ - local: clis
19
+ title: Command Line Interface (CLI)
20
+ - local: customization
21
+ title: Customizing the Training
22
+ - local: reducing_memory_usage
23
+ title: Reducing Memory Usage
24
+ - local: speeding_up_training
25
+ title: Speeding Up Training
26
+ - local: distributing_training
27
+ title: Distributing Training
28
+ - local: use_model
29
+ title: Using Trained Models
30
+ title: How-to guides
31
+ - sections:
32
+ - local: deepspeed_integration
33
+ title: DeepSpeed
34
+ - local: liger_kernel_integration
35
+ title: Liger Kernel
36
+ - local: peft_integration
37
+ title: PEFT
38
+ - local: unsloth_integration
39
+ title: Unsloth
40
+ - local: vllm_integration
41
+ title: vLLM
42
+ title: Integrations
43
+ - sections:
44
+ - local: example_overview
45
+ title: Example Overview
46
+ - local: community_tutorials
47
+ title: Community Tutorials
48
+ - local: sentiment_tuning
49
+ title: Sentiment Tuning
50
+ - local: using_llama_models
51
+ title: Training StackLlama
52
+ - local: detoxifying_a_lm
53
+ title: Detoxifying a Language Model
54
+ - local: multi_adapter_rl
55
+ title: Multi Adapter RLHF
56
+ - local: training_vlm_sft
57
+ title: Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset)
58
+ title: Examples
59
+ - sections:
60
+ - sections: # Sorted alphabetically
61
+ - local: alignprop_trainer
62
+ title: AlignProp
63
+ - local: bco_trainer
64
+ title: BCO
65
+ - local: cpo_trainer
66
+ title: CPO
67
+ - local: ddpo_trainer
68
+ title: DDPO
69
+ - local: dpo_trainer
70
+ title: DPO
71
+ - local: online_dpo_trainer
72
+ title: Online DPO
73
+ - local: gkd_trainer
74
+ title: GKD
75
+ - local: grpo_trainer
76
+ title: GRPO
77
+ - local: kto_trainer
78
+ title: KTO
79
+ - local: nash_md_trainer
80
+ title: Nash-MD
81
+ - local: orpo_trainer
82
+ title: ORPO
83
+ - local: ppo_trainer
84
+ title: PPO
85
+ - local: prm_trainer
86
+ title: PRM
87
+ - local: reward_trainer
88
+ title: Reward
89
+ - local: rloo_trainer
90
+ title: RLOO
91
+ - local: sft_trainer
92
+ title: SFT
93
+ - local: iterative_sft_trainer
94
+ title: Iterative SFT
95
+ - local: xpo_trainer
96
+ title: XPO
97
+ title: Trainers
98
+ - local: models
99
+ title: Model Classes
100
+ - local: model_utils
101
+ title: Model Utilities
102
+ - local: best_of_n
103
+ title: Best of N Sampling
104
+ - local: judges
105
+ title: Judges
106
+ - local: callbacks
107
+ title: Callbacks
108
+ - local: data_utils
109
+ title: Data Utilities
110
+ - local: rewards
111
+ title: Reward Functions
112
+ - local: script_utils
113
+ title: Script Utilities
114
+ - local: others
115
+ title: Others
116
+ title: API
docs/source/alignprop_trainer.md ADDED
@@ -0,0 +1,93 @@
1
+ # Aligning Text-to-Image Diffusion Models with Reward Backpropagation
2
+
3
+ [![](https://img.shields.io/badge/All_models-AlignProp-blue)](https://huggingface.co/models?other=alignprop,trl)
4
+
5
+ ## The why
6
+
7
+ If your reward function is differentiable, directly backpropagating gradients from the reward model to the diffusion model is significantly more sample- and compute-efficient (25x) than using a policy gradient algorithm like DDPO.
8
+ AlignProp does full backpropagation through time, which allows updating the earlier steps of denoising via reward backpropagation.
9
+
10
+ <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/reward_tuning.png"/></div>
11
+
12
+
13
+ ## Getting started with `examples/scripts/alignprop.py`
14
+
15
+ The `alignprop.py` script is a working example of using the `AlignProp` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`AlignPropConfig`).
16
+
17
+ **Note:** one A100 GPU is recommended to get this running. For a lower-memory setting, consider setting `truncated_backprop_rand` to `False`. With the default settings this will do truncated backpropagation with K=1.
18
+
19
+ Almost every configuration parameter has a default. Only one command-line flag is required of the user to get things up and running: a [Hugging Face user access token](https://huggingface.co/docs/hub/security-tokens), which is used to upload the finetuned model to the Hugging Face Hub. Run the following bash command to get things running:
20
+
21
+ ```bash
22
+ python alignprop.py --hf_user_access_token <token>
23
+ ```
24
+
25
+ To obtain the documentation of `alignprop.py`, please run `python alignprop.py --help`
26
+
27
+ Keep the following in mind when configuring the trainer (beyond the example script use case); the code also checks these constraints for you:
28
+
29
+ - For the configurable randomized truncation range (`--alignprop_config.truncated_rand_backprop_minmax=(0,50)`), the first number should be greater than or equal to 0, while the second number should be less than or equal to the number of diffusion timesteps (`sample_num_steps`)
30
+ - The configurable truncation backprop absolute step (`--alignprop_config.truncated_backprop_timestep=49`) should be less than the number of diffusion timesteps (`sample_num_steps`); it only matters when `truncated_backprop_rand` is set to `False`. Both options are illustrated in the sketch below.
31
+
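+ The following is a minimal sketch (not taken from the example script) of how these options map onto `AlignPropConfig`; the parameter names come from the bullet points above and the values are placeholders.
+
+ ```python
+ from trl import AlignPropConfig
+
+ config = AlignPropConfig(
+     sample_num_steps=50,                     # number of diffusion timesteps
+     truncated_backprop_rand=True,            # randomize the truncation point at each update
+     truncated_rand_backprop_minmax=(0, 50),  # must satisfy 0 <= min and max <= sample_num_steps
+     # truncated_backprop_timestep=49,        # only used when truncated_backprop_rand=False
+ )
+ ```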
32
+ ## Setting up the image logging hook function
33
+
34
+ Expect the function to be given a dictionary with keys
35
+ ```python
36
+ ['image', 'prompt', 'prompt_metadata', 'rewards']
37
+
38
+ ```
39
+ and `image`, `prompt`, `prompt_metadata`, and `rewards` are batched.
40
+ You are free to log however you want, though the use of `wandb` or `tensorboard` is recommended.
41
+
42
+ ### Key terms
43
+
44
+ - `rewards` : The reward/score is a numerical value associated with the generated image and is key to steering the RL process
45
+ - `prompt` : The prompt is the text that is used to generate the image
46
+ - `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model consists of a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup, where questions and ground-truth answers (linked to the generated image) are expected alongside the generated image (see here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
47
+ - `image` : The image generated by the Stable Diffusion model
48
+
49
+ Example code for logging sampled images with `wandb` is given below.
50
+
51
+ ```python
52
+ import numpy as np
+ from PIL import Image
+
+ # for logging these images to wandb
53
+
54
+ def image_outputs_hook(image_data, global_step, accelerate_logger):
55
+ # `image_data` is a dictionary carrying the batched images, prompts and rewards for this step
57
+ result = {}
58
+ images, prompts, rewards = [image_data['images'],image_data['prompts'],image_data['rewards']]
59
+ for i, image in enumerate(images):
60
+ pil = Image.fromarray(
61
+ (image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
62
+ )
63
+ pil = pil.resize((256, 256))
64
+ result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
65
+ accelerate_logger.log_images(
66
+ result,
67
+ step=global_step,
68
+ )
69
+
70
+ ```
71
+
72
+ ### Using the finetuned model
73
+
74
+ Assuming you're done with all the epochs and have pushed your model up to the Hub, you can use the finetuned model as follows:
75
+
76
+ ```python
77
+ from diffusers import StableDiffusionPipeline
78
+ pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
79
+ pipeline.to("cuda")
80
+
81
+ pipeline.load_lora_weights('mihirpd/alignprop-trl-aesthetics')
82
+
83
+ prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
84
+ results = pipeline(prompts)
85
+
86
+ for prompt, image in zip(prompts,results.images):
87
+ image.save(f"dump/{prompt}.png")
88
+ ```
89
+
90
+ ## Credits
91
+
92
+ This work is heavily influenced by the repo [here](https://github.com/mihirp1998/AlignProp/) and the associated paper [Aligning Text-to-Image Diffusion Models with Reward Backpropagation
93
+ by Mihir Prabhudesai, Anirudh Goyal, Deepak Pathak, Katerina Fragkiadaki](https://huggingface.co/papers/2310.03739).
docs/source/bco_trainer.md ADDED
@@ -0,0 +1,100 @@
1
+ # BCO Trainer
2
+
3
+ [![](https://img.shields.io/badge/All_models-BCO-blue)](https://huggingface.co/models?other=bco,trl)
4
+
5
+ TRL supports the Binary Classifier Optimization (BCO).
6
+ The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0.
7
+ For a full example, have a look at [`examples/scripts/bco.py`].
8
+
9
+ ## Expected dataset type
10
+
11
+ The [`BCOTrainer`] requires an [unpaired preference dataset](dataset_formats#unpaired-preference).
12
+ The [`BCOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
13
+
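+ As a quick illustration, a single training example in the standard unpaired preference format looks like this (see the [dataset formats guide](dataset_formats#unpaired-preference)):
+
+ ```python
+ # One unpaired preference example: a prompt, a completion, and a binary label
+ # indicating whether the completion is desirable (thumbs-up) or not.
+ example = {"prompt": "The sky is", "completion": " blue.", "label": True}
+ ```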
14
+ ## Expected model format
15
+ The BCO trainer expects a model of type `AutoModelForCausalLM`, whereas PPO expects `AutoModelForCausalLMWithValueHead` for the value function.
16
+
17
+ ## Using the `BCOTrainer`
18
+
19
+ For a detailed example, have a look at the `examples/scripts/bco.py` script. At a high level, we need to initialize the [`BCOTrainer`] with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected responses.
20
+
21
+ The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the three entries of the unpaired preference format (`prompt`, `completion`, and `label`). Note that the `model` and `ref_model` need to have the same architecture (i.e. decoder-only or encoder-decoder).
22
+
23
+
24
+
25
+ ```py
26
+ from trl import BCOConfig, BCOTrainer
+
+ training_args = BCOConfig(
27
+ beta=0.1,
28
+ )
29
+
30
+ bco_trainer = BCOTrainer(
31
+ model,
32
+ model_ref,
33
+ args=training_args,
34
+ train_dataset=train_dataset,
35
+ processing_class=tokenizer,
36
+ )
37
+ ```
38
+ After this one can then call:
39
+
40
+ ```py
41
+ bco_trainer.train()
42
+ ```
43
+
44
+ ## Underlying Distribution matching (UDM)
45
+
46
+ In practical scenarios, the thumbs-up and thumbs-down datasets are likely to have divergent underlying distributions of prompts.
47
+ Consider an LLM deployed for user feedback: if the model excels in writing tasks but underperforms in coding, the thumbs-up dataset will be dominated by writing-related prompts, while the thumbs-down dataset will contain mostly coding-related prompts.
48
+ If the prompts in your desired and undesired datasets differ a lot, it is useful to enable UDM.
49
+
50
+ Choose an embedding model and tokenizer:
51
+
52
+ ```py
53
+ from functools import partial
+
+ from accelerate import Accelerator
+ from transformers import AutoModel, AutoTokenizer
+
+ embedding_model = AutoModel.from_pretrained(your_model_id)
54
+ embedding_tokenizer = AutoTokenizer.from_pretrained(your_model_id)
55
+
56
+ # customize this function depending on your embedding model
57
+ def embed_prompt(input_ids, attention_mask, model):
58
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask)
59
+ return outputs.last_hidden_state.mean(dim=1)
60
+
61
+ embedding_model = Accelerator().prepare_model(embedding_model)
62
+ embedding_func = partial(embed_prompt, model=embedding_model)
63
+ ```
64
+
65
+ Set `prompt_sample_size` to define how many prompts are selected to train the UDM classifier and start the training with the provided embedding function:
66
+
67
+ ```py
68
+ training_args = BCOConfig(
69
+ beta=0.1,
70
+ prompt_sample_size=512,
71
+ )
72
+
73
+ bco_trainer = BCOTrainer(
74
+ model,
75
+ model_ref,
76
+ args=training_args,
77
+ train_dataset=train_dataset,
78
+ processing_class=tokenizer,
79
+ embedding_func=embedding_func,
80
+ embedding_tokenizer=embedding_tokenizer,
81
+ )
82
+
83
+ bco_trainer.train()
84
+ ```
85
+
86
+ ### For Mixture of Experts Models: Enabling the auxiliary loss
87
+
88
+ MOEs are the most efficient if the load is about equally distributed between experts.
89
+ To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
90
+
91
+ This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig).
92
+ To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001).
93
+
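+ As a minimal sketch (the checkpoint name below is only an example), these two options can be set on the model config before passing the model to the trainer:
+
+ ```python
+ from transformers import AutoConfig, AutoModelForCausalLM
+
+ model_id = "mistralai/Mixtral-8x7B-v0.1"  # example MoE checkpoint
+ config = AutoConfig.from_pretrained(model_id)
+ config.output_router_logits = True    # add the load-balancing auxiliary loss to the total loss
+ config.router_aux_loss_coef = 0.001   # weight of the auxiliary loss (the default)
+ model = AutoModelForCausalLM.from_pretrained(model_id, config=config)
+ ```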
94
+ ## BCOTrainer
95
+
96
+ [[autodoc]] BCOTrainer
97
+
98
+ ## BCOConfig
99
+
100
+ [[autodoc]] BCOConfig
docs/source/best_of_n.md ADDED
@@ -0,0 +1,72 @@
1
+ # Best of N sampling: Alternative ways to get better model output without RL based fine-tuning
2
+
3
+ Within the extras module is the `best-of-n` sampler class that serves as an alternative method of generating better model output.
4
+ For a comparison of how it fares against RL-based fine-tuning, please look at the comparison example in the `examples` directory.
5
+
6
+ ## Usage
7
+
8
+ To get started quickly, create an instance of the class with a model, a length sampler, a tokenizer, and a callable that serves as a proxy reward pipeline, outputting reward scores for input queries:
9
+
10
+ ```python
11
+
12
+ from transformers import pipeline, AutoTokenizer
13
+ from trl import AutoModelForCausalLMWithValueHead
14
+ from trl.core import LengthSampler
15
+ from trl.extras import BestOfNSampler
16
+
17
+ # the model you want to sample from, plus a reference model (both names are placeholders here)
+ model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)
+ ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)
18
+ reward_pipe = pipeline("sentiment-analysis", model=reward_model, device=device)
19
+ tokenizer = AutoTokenizer.from_pretrained(ref_model_name)
20
+ tokenizer.pad_token = tokenizer.eos_token
21
+
22
+
23
+ # callable that takes a list of raw text and returns a list of corresponding reward scores
24
+ def queries_to_scores(list_of_strings):
25
+ return [output["score"] for output in reward_pipe(list_of_strings)]
26
+
27
+ best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler)
28
+
29
+
30
+ ```
31
+
32
+ And assuming you have a list/tensor of tokenized queries, you can generate better output by calling the `generate` method
33
+
34
+ ```python
35
+
36
+ best_of_n.generate(query_tensors, device=device, **gen_kwargs)
37
+
38
+ ```
39
+ The default sample size is 4, but you can change it at the time of instance initialization like so
40
+
41
+ ```python
42
+
43
+ best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, sample_size=8)
44
+
45
+ ```
46
+
47
+ The default output is the result of taking the top scored output for each query, but you can change it to top 2 and so on by passing the `n_candidates` argument at the time of instance initialization
48
+
49
+ ```python
50
+
51
+ best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, n_candidates=2)
52
+
53
+ ```
54
+
55
+ There is the option of setting the generation settings (like `temperature`, `pad_token_id`) at the time of instance creation as opposed to when calling the `generate` method.
56
+ This is done by passing a `GenerationConfig` from the `transformers` library at the time of initialization
57
+
58
+ ```python
59
+
60
+ from transformers import GenerationConfig
61
+
62
+ generation_config = GenerationConfig(min_length=-1, top_k=0, top_p=1.0, do_sample=True, pad_token_id=tokenizer.eos_token_id)
63
+
64
+ best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, generation_config=generation_config)
65
+
66
+ best_of_n.generate(query_tensors, device=device)
67
+
68
+ ```
69
+
70
+ Furthermore, at the time of initialization you can set the seed to control the repeatability of the generation process, as well as the number of samples to generate for each query.
71
+
72
+
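+ For example, a minimal sketch assuming the `seed` and `sample_size` keyword arguments described above:
+
+ ```python
+
+ best_of_n = BestOfNSampler(
+     model,
+     tokenizer,
+     queries_to_scores,
+     length_sampler=output_length_sampler,
+     seed=0,          # makes the sampling reproducible
+     sample_size=8,   # number of samples generated per query
+ )
+
+ ```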
docs/source/callbacks.md ADDED
@@ -0,0 +1,21 @@
 
1
+ # Callbacks
2
+
3
+ ## SyncRefModelCallback
4
+
5
+ [[autodoc]] SyncRefModelCallback
6
+
7
+ ## RichProgressCallback
8
+
9
+ [[autodoc]] RichProgressCallback
10
+
11
+ ## WinRateCallback
12
+
13
+ [[autodoc]] WinRateCallback
14
+
15
+ ## LogCompletionsCallback
16
+
17
+ [[autodoc]] LogCompletionsCallback
18
+
19
+ ## MergeModelCallback
20
+
21
+ [[autodoc]] MergeModelCallback
docs/source/clis.md ADDED
@@ -0,0 +1,272 @@
1
+ # Command Line Interfaces (CLIs)
2
+
3
+ TRL provides a powerful command-line interface (CLI) to fine-tune large language models (LLMs) using methods like Supervised Fine-Tuning (SFT), Direct Preference Optimization (DPO), and more. The CLI abstracts away much of the boilerplate, letting you launch training jobs quickly and reproducibly.
4
+
5
+ Currently supported commands are:
6
+
7
+ #### Training Commands
8
+
9
+ - `trl dpo`: fine-tune an LLM with DPO
10
+ - `trl grpo`: fine-tune an LLM with GRPO
11
+ - `trl kto`: fine-tune an LLM with KTO
12
+ - `trl sft`: fine-tune an LLM with SFT
13
+
14
+ #### Other Commands
15
+
16
+ - `trl env`: get the system information
17
+ - `trl vllm-serve`: serve a model with vLLM
18
+
19
+ ## Fine-Tuning with the TRL CLI
20
+
21
+ ### Basic Usage
22
+
23
+ You can launch training directly from the CLI by specifying required arguments like the model and dataset:
24
+
25
+ <hfoptions id="command_line">
26
+ <hfoption id="SFT">
27
+
28
+ ```bash
29
+ trl sft \
30
+ --model_name_or_path Qwen/Qwen2.5-0.5B \
31
+ --dataset_name stanfordnlp/imdb
32
+ ```
33
+
34
+ </hfoption>
35
+ <hfoption id="DPO">
36
+
37
+ ```bash
38
+ trl dpo \
39
+ --model_name_or_path Qwen/Qwen2.5-0.5B \
40
+ --dataset_name anthropic/hh-rlhf
41
+ ```
42
+
43
+ </hfoption>
44
+ </hfoptions>
45
+
46
+ ### Using Configuration Files
47
+
48
+ To keep your CLI commands clean and reproducible, you can define all training arguments in a YAML configuration file:
49
+
50
+ <hfoptions id="config_file">
51
+ <hfoption id="SFT">
52
+
53
+ ```yaml
54
+ # sft_config.yaml
55
+ model_name_or_path: Qwen/Qwen2.5-0.5B
56
+ dataset_name: stanfordnlp/imdb
57
+ ```
58
+
59
+ Launch with:
60
+
61
+ ```bash
62
+ trl sft --config sft_config.yaml
63
+ ```
64
+
65
+ </hfoption>
66
+ <hfoption id="DPO">
67
+
68
+ ```yaml
69
+ # dpo_config.yaml
70
+ model_name_or_path: Qwen/Qwen2.5-0.5B
71
+ dataset_name: anthropic/hh-rlhf
72
+ ```
73
+
74
+ Launch with:
75
+
76
+ ```bash
77
+ trl dpo --config dpo_config.yaml
78
+ ```
79
+
80
+ </hfoption>
81
+ </hfoptions>
82
+
83
+ ### Scaling Up with Accelerate
84
+
85
+ TRL CLI natively supports [🤗 Accelerate](https://huggingface.co/docs/accelerate), making it easy to scale training across multiple GPUs, machines, or use advanced setups like DeepSpeed — all from the same CLI.
86
+
87
+ You can pass any `accelerate launch` arguments directly to `trl`, such as `--num_processes`. For more information see [Using accelerate launch](https://huggingface.co/docs/accelerate/en/basic_tutorials/launch#using-accelerate-launch).
88
+
89
+ <hfoptions id="launch_args">
90
+ <hfoption id="SFT inline">
91
+
92
+ ```bash
93
+ trl sft \
94
+ --model_name_or_path Qwen/Qwen2.5-0.5B \
95
+ --dataset_name stanfordnlp/imdb \
96
+ --num_processes 4
97
+ ```
98
+
99
+ </hfoption>
100
+ <hfoption id="SFT w/ config file">
101
+
102
+ ```yaml
103
+ # sft_config.yaml
104
+ model_name_or_path: Qwen/Qwen2.5-0.5B
105
+ dataset_name: stanfordnlp/imdb
106
+ num_processes: 4
107
+ ```
108
+
109
+ Launch with:
110
+
111
+ ```bash
112
+ trl sft --config sft_config.yaml
113
+ ```
114
+
115
+ </hfoption>
116
+ <hfoption id="DPO inline">
117
+
118
+ ```bash
119
+ trl dpo \
120
+ --model_name_or_path Qwen/Qwen2.5-0.5B \
121
+ --dataset_name anthropic/hh-rlhf \
122
+ --num_processes 4
123
+ ```
124
+
125
+ </hfoption>
126
+ <hfoption id="DPO w/ config file">
127
+
128
+ ```yaml
129
+ # dpo_config.yaml
130
+ model_name_or_path: Qwen/Qwen2.5-0.5B
131
+ dataset_name: anthropic/hh-rlhf
132
+ num_processes: 4
133
+ ```
134
+
135
+ Launch with:
136
+
137
+ ```bash
138
+ trl dpo --config dpo_config.yaml
139
+ ```
140
+ </hfoption>
141
+ </hfoptions>
142
+
143
+ ### Using `--accelerate_config` for Accelerate Configuration
144
+
145
+ The `--accelerate_config` flag lets you easily configure distributed training with [🤗 Accelerate](https://github.com/huggingface/accelerate). This flag accepts either:
146
+
147
+ * the name of a predefined config profile (built into TRL), or
148
+ * a path to a custom Accelerate YAML config file.
149
+
150
+ #### Predefined Config Profiles
151
+
152
+ TRL provides several ready-to-use Accelerate configs to simplify common training setups:
153
+
154
+ | Name | Description |
155
+ | ------------ | ----------------------------------- |
156
+ | `fsdp1` | Fully Sharded Data Parallel Stage 1 |
157
+ | `fsdp2` | Fully Sharded Data Parallel Stage 2 |
158
+ | `zero1` | DeepSpeed ZeRO Stage 1 |
159
+ | `zero2` | DeepSpeed ZeRO Stage 2 |
160
+ | `zero3` | DeepSpeed ZeRO Stage 3 |
161
+ | `multi_gpu` | Multi-GPU training |
162
+ | `single_gpu` | Single-GPU training |
163
+
164
+ To use one of these, just pass the name to `--accelerate_config`. TRL will automatically load the corresponding config file from `trl/accelerate_config/`.
165
+
166
+ #### Example Usage
167
+
168
+ <hfoptions id="accelerate_config">
169
+ <hfoption id="SFT inline">
170
+
171
+ ```bash
172
+ trl sft \
173
+ --model_name_or_path Qwen/Qwen2.5-0.5B \
174
+ --dataset_name stanfordnlp/imdb \
175
+ --accelerate_config zero2 # or path/to/my/accelerate/config.yaml
176
+ ```
177
+
178
+ </hfoption>
179
+ <hfoption id="SFT w/ config file">
180
+
181
+ ```yaml
182
+ # sft_config.yaml
183
+ model_name_or_path: Qwen/Qwen2.5-0.5B
184
+ dataset_name: stanfordnlp/imdb
185
+ accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
186
+ ```
187
+
188
+ Launch with:
189
+
190
+ ```bash
191
+ trl sft --config sft_config.yaml
192
+ ```
193
+
194
+ </hfoption>
195
+ <hfoption id="DPO inline">
196
+
197
+ ```bash
198
+ trl dpo \
199
+ --model_name_or_path Qwen/Qwen2.5-0.5B \
200
+ --dataset_name anthropic/hh-rlhf \
201
+ --accelerate_config zero2 # or path/to/my/accelerate/config.yaml
202
+ ```
203
+
204
+ </hfoption>
205
+ <hfoption id="DPO w/ config file">
206
+
207
+ ```yaml
208
+ # dpo_config.yaml
209
+ model_name_or_path: Qwen/Qwen2.5-0.5B
210
+ dataset_name: anthropic/hh-rlhf
211
+ accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
212
+ ```
213
+
214
+ Launch with:
215
+
216
+ ```bash
217
+ trl dpo --config dpo_config.yaml
218
+ ```
219
+ </hfoption>
220
+ </hfoptions>
221
+
222
+ ## Getting the System Information
223
+
224
+ You can get the system information by running the following command:
225
+
226
+ ```bash
227
+ trl env
228
+ ```
229
+
230
+ This will print out the system information, including the GPU information, the CUDA version, the PyTorch version, the transformers version, the TRL version, and any optional dependencies that are installed.
231
+
232
+ ```txt
233
+ Copy-paste the following information when reporting an issue:
234
+
235
+ - Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
236
+ - Python version: 3.11.9
237
+ - PyTorch version: 2.4.1
238
+ - accelerator(s): NVIDIA H100 80GB HBM3
239
+ - Transformers version: 4.45.0.dev0
240
+ - Accelerate version: 0.34.2
241
+ - Accelerate config:
242
+ - compute_environment: LOCAL_MACHINE
243
+ - distributed_type: DEEPSPEED
244
+ - mixed_precision: no
245
+ - use_cpu: False
246
+ - debug: False
247
+ - num_processes: 4
248
+ - machine_rank: 0
249
+ - num_machines: 1
250
+ - rdzv_backend: static
251
+ - same_network: True
252
+ - main_training_function: main
253
+ - enable_cpu_affinity: False
254
+ - deepspeed_config: {'gradient_accumulation_steps': 4, 'offload_optimizer_device': 'none', 'offload_param_device': 'none', 'zero3_init_flag': False, 'zero_stage': 2}
255
+ - downcast_bf16: no
256
+ - tpu_use_cluster: False
257
+ - tpu_use_sudo: False
258
+ - tpu_env: []
259
+ - Datasets version: 3.0.0
260
+ - HF Hub version: 0.24.7
261
+ - TRL version: 0.12.0.dev0+acb4d70
262
+ - bitsandbytes version: 0.41.1
263
+ - DeepSpeed version: 0.15.1
264
+ - Diffusers version: 0.30.3
265
+ - Liger-Kernel version: 0.3.0
266
+ - LLM-Blender version: 0.0.2
267
+ - OpenAI version: 1.46.0
268
+ - PEFT version: 0.12.0
269
+ - vLLM version: not installed
270
+ ```
271
+
272
+ This information is required when reporting an issue.
docs/source/community_tutorials.md ADDED
@@ -0,0 +1,32 @@
1
+ # Community Tutorials
2
+
3
+ Community tutorials are made by active members of the Hugging Face community who want to share their knowledge and expertise with others. They are a great way to learn about the library and its features, and to get started with core classes and modalities.
4
+
5
+ ## Language Models
6
+
7
+ | Task | Class | Description | Author | Tutorial | Colab |
8
+ | --- | --- | --- | --- | --- | --- |
9
+ | Reinforcement Learning | [`GRPOTrainer`] | Post training an LLM for reasoning with GRPO in TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_llm_grpo_trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_grpo_trl.ipynb) |
10
+ | Reinforcement Learning | [`GRPOTrainer`] | Mini-R1: Reproduce Deepseek R1 „aha moment“ a RL tutorial | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/mini-deepseek-r1) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/mini-deepseek-r1-aha-grpo.ipynb) |
11
+ | Reinforcement Learning | [`GRPOTrainer`] | RL on LLaMA 3.1-8B with GRPO and Unsloth optimizations | [Andrea Manzoni](https://huggingface.co/AManzoni) | [Link](https://colab.research.google.com/github/amanzoni1/fine_tuning/blob/main/RL_LLama3_1_8B_GRPO.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/amanzoni1/fine_tuning/blob/main/RL_LLama3_1_8B_GRPO.ipynb) |
12
+ | Instruction tuning | [`SFTTrainer`] | Fine-tuning Google Gemma LLMs using ChatML format with QLoRA | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-google-gemma) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/gemma-lora-example.ipynb) |
13
+ | Structured Generation | [`SFTTrainer`] | Fine-tuning Llama-2-7B to generate Persian product catalogs in JSON using QLoRA and PEFT | [Mohammadreza Esmaeilian](https://huggingface.co/Mohammadreza) | [Link](https://huggingface.co/learn/cookbook/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format.ipynb) |
14
+ | Preference Optimization | [`DPOTrainer`] | Align Mistral-7b using Direct Preference Optimization for human preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/Fine_tune_Mistral_7b_with_DPO.html) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mlabonne/llm-course/blob/main/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb) |
15
+ | Preference Optimization | [`ORPOTrainer`] | Fine-tuning Llama 3 with ORPO combining instruction tuning and preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/2024-04-19_Fine_tune_Llama_3_with_ORPO.html) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eHNWg9gnaXErdAa8_mcvjMupbSS6rDvi) |
16
+ | Instruction tuning | [`SFTTrainer`] | How to fine-tune open LLMs in 2025 with Hugging Face | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-llms-in-2025) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-llms-in-2025.ipynb) |
17
+
18
+ <Youtube id="cnGyyM0vOes" />
19
+
20
+ ## Vision Language Models
21
+
22
+ | Task | Class | Description | Author | Tutorial | Colab |
23
+ | --- | --- | --- | --- | --- | --- |
24
+ | Visual QA | [`SFTTrainer`] | Fine-tuning Qwen2-VL-7B for visual question answering on ChartQA dataset | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_trl.ipynb) |
25
+ | Visual QA | [`SFTTrainer`] | Fine-tuning SmolVLM with TRL on a consumer GPU | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_smol_vlm_sft_trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_smol_vlm_sft_trl.ipynb) |
26
+ | SEO Description | [`SFTTrainer`] | Fine-tuning Qwen2-VL-7B for generating SEO-friendly descriptions from images | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-multimodal-llms-with-trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-multimodal-llms-with-trl.ipynb) |
27
+ | Visual QA | [`DPOTrainer`] | PaliGemma 🤝 Direct Preference Optimization | [Merve Noyan](https://huggingface.co/merve) | [Link](https://github.com/merveenoyan/smol-vision/blob/main/PaliGemma_DPO.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/merveenoyan/smol-vision/blob/main/PaliGemma_DPO.ipynb) |
28
+ | Visual QA | [`DPOTrainer`] | Fine-tuning SmolVLM using direct preference optimization (DPO) with TRL on a consumer GPU | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_dpo_smolvlm_instruct) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_dpo_smolvlm_instruct.ipynb) |
29
+
30
+ ## Contributing
31
+
32
+ If you have a tutorial that you would like to add to this list, please open a PR to add it. We will review it and merge it if it is relevant to the community.
docs/source/cpo_trainer.md ADDED
@@ -0,0 +1,108 @@
1
+ # CPO Trainer
2
+
3
+ [![](https://img.shields.io/badge/All_models-CPO-blue)](https://huggingface.co/models?other=cpo,trl)
4
+
5
+ ## Overview
6
+
7
+ Contrastive Preference Optimization (CPO) was introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by [Haoran Xu](https://huggingface.co/haoranxu), [Amr Sharaf](https://huggingface.co/amrsharaf), [Yunmo Chen](https://huggingface.co/yunmochen), Weiting Tan, Lingfeng Shen, Benjamin Van Durme, [Kenton Murray](https://huggingface.co/Kenton), and [Young Jin Kim](https://huggingface.co/ykim362). At a high level, CPO trains models to avoid generating translations that are adequate but not perfect in Machine Translation (MT) tasks. However, CPO is a general approximation of the DPO loss and can be applied to other domains, such as chat.
8
+
9
+ CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT’s methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Secondly, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective.
10
+
11
+ ## Quick start
12
+
13
+ This example demonstrates how to train a model using the CPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here:
14
+
15
+ <iframe
16
+ src="https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized/embed/viewer/default/train?row=0"
17
+ frameborder="0"
18
+ width="100%"
19
+ height="560px"
20
+ ></iframe>
21
+
22
+ Below is the script to train the model:
23
+
24
+ ```python
25
+ # train_cpo.py
26
+ from datasets import load_dataset
27
+ from trl import CPOConfig, CPOTrainer
28
+ from transformers import AutoModelForCausalLM, AutoTokenizer
29
+
30
+ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
31
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
32
+ train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
33
+
34
+ training_args = CPOConfig(output_dir="Qwen2-0.5B-CPO", logging_steps=10)
35
+ trainer = CPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
36
+ trainer.train()
37
+ ```
38
+
39
+ Execute the script using the following command:
40
+
41
+ ```bash
42
+ accelerate launch train_cpo.py
43
+ ```
44
+
45
+ ## Expected dataset type
46
+
47
+ CPO requires a [preference dataset](dataset_formats#preference). The [`CPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
48
+
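+ As a quick illustration, a single training example in the standard preference format looks like this (see the [dataset formats guide](dataset_formats#preference)):
+
+ ```python
+ # One preference example: a prompt with a chosen and a rejected completion.
+ example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}
+ ```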
49
+ ## Example script
50
+
51
+ We provide an example script to train a model using the CPO method. The script is available in [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py)
52
+
53
+ To test the CPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command:
54
+
55
+ ```bash
56
+ accelerate launch examples/scripts/cpo.py \
57
+ --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
58
+ --dataset_name trl-lib/ultrafeedback_binarized \
59
+ --num_train_epochs 1 \
60
+ --logging_steps 25 \
61
+ --output_dir Qwen2-0.5B-CPO
62
+ ```
63
+
64
+ ## Logged metrics
65
+
66
+ While training and evaluating we record the following reward metrics:
67
+
68
+ * `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
69
+ * `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
70
+ * `rewards/accuracies`: mean of how often the chosen rewards are greater than the corresponding rejected rewards
71
+ * `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
72
+ * `nll_loss`: the mean negative log likelihood loss of the policy model for the chosen responses
73
+
74
+ ## CPO variants
75
+
76
+ ### Simple Preference Optimization (SimPO)
77
+
78
+ The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, simply set `loss_type="simpo"` and `cpo_alpha=0.0` in the [`CPOConfig`], as in the sketch below.
79
+
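+ A minimal sketch (the `output_dir` below is only an example):
+
+ ```python
+ from trl import CPOConfig
+
+ training_args = CPOConfig(
+     output_dir="Qwen2-0.5B-SimPO",
+     loss_type="simpo",
+     cpo_alpha=0.0,  # disable the BC regularizer to recover the pure SimPO loss
+ )
+ ```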
80
+ ### CPO-SimPO
81
+
82
+ We also offer the combined use of CPO and SimPO, which enables more stable training and improved performance. Learn more details at [CPO-SimPO GitHub](https://github.com/fe1ixxu/CPO_SIMPO). To use this method, simply enable SimPO by setting `loss_type="simpo"` and a non-zero `cpo_alpha` in the [`CPOConfig`].
83
+
84
+ ## Loss functions
85
+
86
+ The CPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`CPOConfig`]. The following loss functions are supported:
87
+
88
+ | `loss_type=` | Description |
89
+ | -------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
90
+ | `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model and in fact the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. |
91
+ | `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. |
92
+ | `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike DPO which is summed only). |
93
+
94
+ ### For Mixture of Experts Models: Enabling the auxiliary loss
95
+
96
+ MOEs are the most efficient if the load is about equally distributed between experts.
97
+ To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
98
+
99
+ This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
100
+ To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
101
+
102
+ ## CPOTrainer
103
+
104
+ [[autodoc]] CPOTrainer
105
+
106
+ ## CPOConfig
107
+
108
+ [[autodoc]] CPOConfig
docs/source/customization.md ADDED
@@ -0,0 +1,121 @@
1
+ # Training customization
2
+
3
+ TRL is designed with modularity in mind so that users are able to efficiently customize the training loop for their needs. Below are some examples of how you can apply and test different techniques. Note: Although these examples use the DPOTrainer, the customization applies to most (if not all) trainers.
4
+
5
+
6
+
7
+ ## Use different optimizers and schedulers
8
+
9
+ By default, the `DPOTrainer` creates a `torch.optim.AdamW` optimizer. You can create and define a different optimizer and pass it to `DPOTrainer` as follows:
10
+
11
+ ```python
12
+ from datasets import load_dataset
13
+ from transformers import AutoModelForCausalLM, AutoTokenizer
14
+ from torch import optim
15
+ from trl import DPOConfig, DPOTrainer
16
+
17
+ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
18
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
19
+ dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
20
+ training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
21
+
22
+ optimizer = optim.SGD(model.parameters(), lr=training_args.learning_rate)
23
+
24
+ trainer = DPOTrainer(
25
+ model=model,
26
+ args=training_args,
27
+ train_dataset=dataset,
28
+ tokenizer=tokenizer,
29
+ optimizers=(optimizer, None),
30
+ )
31
+ trainer.train()
32
+ ```
33
+
34
+ ### Add a learning rate scheduler
35
+
36
+ You can also play with your training by adding learning rate schedulers.
37
+
38
+ ```python
39
+ from datasets import load_dataset
40
+ from transformers import AutoModelForCausalLM, AutoTokenizer
41
+ from torch import optim
42
+ from trl import DPOConfig, DPOTrainer
43
+
44
+ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
45
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
46
+ dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
47
+ training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
48
+
49
+ optimizer = optim.AdamW(model.parameters(), lr=training_args.learning_rate)
50
+ lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
51
+
52
+ trainer = DPOTrainer(
53
+ model=model,
54
+ args=training_args,
55
+ train_dataset=dataset,
56
+ tokenizer=tokenizer,
57
+ optimizers=(optimizer, lr_scheduler),
58
+ )
59
+ trainer.train()
60
+ ```
61
+
62
+ ## Memory efficient fine-tuning by sharing layers
63
+
64
+ Another tool you can use for more memory efficient fine-tuning is to share layers between the reference model and the model you want to train.
65
+
66
+ ```python
67
+ from datasets import load_dataset
68
+ from transformers import AutoModelForCausalLM, AutoTokenizer
69
+ from trl import create_reference_model, DPOConfig, DPOTrainer
70
+
71
+ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
72
+ ref_model = create_reference_model(model, num_shared_layers=6)
73
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
74
+ dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train[:1%]")
75
+ training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
76
+
77
+ trainer = DPOTrainer(
78
+ model=model,
79
+ ref_model=ref_model,
80
+ args=training_args,
81
+ train_dataset=dataset,
82
+ tokenizer=tokenizer,
83
+ )
84
+ trainer.train()
85
+ ```
86
+
87
+ ## Pass 8-bit reference models
88
+
89
+ Since `trl` supports all keyword arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage `load_in_8bit` from `transformers` for more memory efficient fine-tuning.
90
+
91
+ Read more about 8-bit model loading in `transformers` [here](https://huggingface.co/docs/transformers/en/peft#load-in-8bit-or-4bit).
92
+
93
+ ```python
94
+ from datasets import load_dataset
95
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
96
+ from trl import DPOConfig, DPOTrainer
97
+
98
+ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
99
+ quantization_config = BitsAndBytesConfig(load_in_8bit=True)
100
+ ref_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", quantization_config= quantization_config)
101
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
102
+ dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
103
+ training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
104
+
105
+ trainer = DPOTrainer(
106
+ model=model,
107
+ ref_model=ref_model,
108
+ args=training_args,
109
+ train_dataset=dataset,
110
+ tokenizer=tokenizer,
111
+ )
112
+ trainer.train()
113
+ ```
114
+
115
+ ## Use the accelerator cache optimizer
116
+
117
+ When training large models, you should take care to handle the accelerator cache by clearing it iteratively. To do so, simply pass `optimize_device_cache=True` to `DPOConfig`:
118
+
119
+ ```python
120
+ training_args = DPOConfig(..., optimize_device_cache=True)
121
+ ```
docs/source/data_utils.md ADDED
@@ -0,0 +1,41 @@
1
+ # Data Utilities
2
+
3
+ ## is_conversational
4
+
5
+ [[autodoc]] is_conversational
6
+
7
+ ## apply_chat_template
8
+
9
+ [[autodoc]] apply_chat_template
10
+
11
+ ## maybe_apply_chat_template
12
+
13
+ [[autodoc]] maybe_apply_chat_template
14
+
15
+ ## maybe_convert_to_chatml
16
+
17
+ [[autodoc]] maybe_convert_to_chatml
18
+
19
+ ## extract_prompt
20
+
21
+ [[autodoc]] extract_prompt
22
+
23
+ ## maybe_extract_prompt
24
+
25
+ [[autodoc]] maybe_extract_prompt
26
+
27
+ ## unpair_preference_dataset
28
+
29
+ [[autodoc]] unpair_preference_dataset
30
+
31
+ ## maybe_unpair_preference_dataset
32
+
33
+ [[autodoc]] maybe_unpair_preference_dataset
34
+
35
+ ## pack_dataset
36
+
37
+ [[autodoc]] pack_dataset
38
+
39
+ ## truncate_dataset
40
+
41
+ [[autodoc]] truncate_dataset
docs/source/dataset_formats.md ADDED
@@ -0,0 +1,938 @@
1
+ # Dataset formats and types
2
+
3
+ This guide provides an overview of the dataset formats and types supported by each trainer in TRL.
4
+
5
+ ## Overview of the dataset formats and types
6
+
7
+ - The *format* of a dataset refers to how the data is structured, typically categorized as either *standard* or *conversational*.
8
+ - The *type* is associated with the specific task the dataset is designed for, such as *prompt-only* or *preference*. Each type is characterized by its columns, which vary according to the task, as shown in the table.
9
+
10
+ <table>
11
+ <tr>
12
+ <th>Type \ Format</th>
13
+ <th>Standard</th>
14
+ <th>Conversational</th>
15
+ </tr>
16
+ <tr>
17
+ <td>Language modeling</td>
18
+ <td>
19
+ <pre><code>{"text": "The sky is blue."}</code></pre>
20
+ </td>
21
+ <td>
22
+ <pre><code>{"messages": [{"role": "user", "content": "What color is the sky?"},
23
+ {"role": "assistant", "content": "It is blue."}]}</code></pre>
24
+ </td>
25
+ </tr>
26
+ <tr>
27
+ <td>Prompt-only</td>
28
+ <td>
29
+ <pre><code>{"prompt": "The sky is"}</code></pre>
30
+ </td>
31
+ <td>
32
+ <pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}]}</code></pre>
33
+ </td>
34
+ </tr>
35
+ <tr>
36
+ <td>Prompt-completion</td>
37
+ <td>
38
+ <pre><code>{"prompt": "The sky is",
39
+ "completion": " blue."}</code></pre>
40
+ </td>
41
+ <td>
42
+ <pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}],
43
+ "completion": [{"role": "assistant", "content": "It is blue."}]}</code></pre>
44
+ </td>
45
+ </tr>
46
+ <tr>
48
+ <td>Preference</td>
49
+ <td>
50
+ <pre><code>{"prompt": "The sky is",
51
+ "chosen": " blue.",
52
+ "rejected": " green."}</code></pre>
53
+ or, with implicit prompt:
54
+ <pre><code>{"chosen": "The sky is blue.",
55
+ "rejected": "The sky is green."}</code></pre>
56
+ </td>
57
+ <td>
58
+ <pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}],
59
+ "chosen": [{"role": "assistant", "content": "It is blue."}],
60
+ "rejected": [{"role": "assistant", "content": "It is green."}]}</code></pre>
61
+ or, with implicit prompt:
62
+ <pre><code>{"chosen": [{"role": "user", "content": "What color is the sky?"},
63
+ {"role": "assistant", "content": "It is blue."}],
64
+ "rejected": [{"role": "user", "content": "What color is the sky?"},
65
+ {"role": "assistant", "content": "It is green."}]}</code></pre>
66
+ </td>
67
+ </tr>
68
+ <tr>
+ <td>Unpaired preference</td>
69
+ <td>
70
+ <pre><code>{"prompt": "The sky is",
71
+ "completion": " blue.",
72
+ "label": True}</code></pre>
73
+ </td>
74
+ <td>
75
+ <pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}],
76
+ "completion": [{"role": "assistant", "content": "It is green."}],
77
+ "label": False}</code></pre>
78
+ </td>
79
+ </tr>
80
+ <tr>
+ <td>Stepwise supervision</td>
82
+ <td>
83
+ <pre><code>{"prompt": "Which number is larger, 9.8 or 9.11?",
84
+ "completions": ["The fractional part of 9.8 is 0.8.",
85
+ "The fractional part of 9.11 is 0.11.",
86
+ "0.11 is greater than 0.8.",
87
+ "Hence, 9.11 > 9.8."],
88
+ "labels": [True, True, False, False]}</code></pre>
89
+ </td>
90
+ <td></td>
91
+ </tr>
92
+ </table>
93
+
94
+ ### Formats
95
+
96
+ #### Standard
97
+
98
+ The standard dataset format typically consists of plain text strings. The columns in the dataset vary depending on the task. This is the format expected by TRL trainers. Below are examples of standard dataset formats for different tasks:
99
+
100
+ ```python
101
+ # Language modeling
102
+ language_modeling_example = {"text": "The sky is blue."}
103
+ # Preference
104
+ preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}
105
+ # Unpaired preference
106
+ unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "label": True}
107
+ ```
108
+
109
+ #### Conversational
110
+
111
+ Conversational datasets are used for tasks involving dialogues or chat interactions between users and assistants. Unlike standard dataset formats, these contain sequences of messages where each message has a `role` (e.g., `"user"` or `"assistant"`) and `content` (the message text).
112
+
113
+ ```python
114
+ messages = [
115
+ {"role": "user", "content": "Hello, how are you?"},
116
+ {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
117
+ {"role": "user", "content": "I'd like to show off how chat templating works!"},
118
+ ]
119
+ ```
120
+
121
+ Just like standard datasets, the columns in conversational datasets vary depending on the task. Below are examples of conversational dataset formats for different tasks:
122
+
123
+ ```python
124
+ # Prompt-completion
125
+ prompt_completion_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
126
+ "completion": [{"role": "assistant", "content": "It is blue."}]}
127
+ # Preference
128
+ preference_example = {
129
+ "prompt": [{"role": "user", "content": "What color is the sky?"}],
130
+ "chosen": [{"role": "assistant", "content": "It is blue."}],
131
+ "rejected": [{"role": "assistant", "content": "It is green."}],
132
+ }
133
+ ```
134
+
135
+ Conversational datasets are useful for training chat models, but must be converted into a standard format before being used with TRL trainers. This is typically done using chat templates specific to the model being used. For more information, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section.
136
+
137
+ ### Types
138
+
139
+ #### Language modeling
140
+
141
+ A language modeling dataset consists of a column `"text"` (or `"messages"` for conversational datasets) containing a full sequence of text.
142
+
143
+ ```python
144
+ # Standard format
145
+ language_modeling_example = {"text": "The sky is blue."}
146
+ # Conversational format
147
+ language_modeling_example = {"messages": [
148
+ {"role": "user", "content": "What color is the sky?"},
149
+ {"role": "assistant", "content": "It is blue."}
150
+ ]}
151
+ ```
152
+
153
+ #### Prompt-only
154
+
155
+ In a prompt-only dataset, only the initial prompt (the question or partial sentence) is provided under the key `"prompt"`. The training typically involves generating the completion based on this prompt, where the model learns to continue or complete the given input.
156
+
157
+ ```python
158
+ # Standard format
159
+ prompt_only_example = {"prompt": "The sky is"}
160
+ # Conversational format
161
+ prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
162
+ ```
163
+
164
+ For examples of prompt-only datasets, refer to the [Prompt-only datasets collection](https://huggingface.co/collections/trl-lib/prompt-only-datasets-677ea25245d20252cea00368).
165
+
166
+ <Tip>
167
+
168
+ While both the prompt-only and language modeling types are similar, they differ in how the input is handled. In the prompt-only type, the prompt represents a partial input that expects the model to complete or continue, while in the language modeling type, the input is treated as a complete sentence or sequence. These two types are processed differently by TRL. Below is an example showing the difference in the output of the `apply_chat_template` function for each type:
169
+
170
+ ```python
171
+ from transformers import AutoTokenizer
172
+ from trl import apply_chat_template
173
+
174
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
175
+
176
+ # Example for prompt-only type
177
+ prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
178
+ apply_chat_template(prompt_only_example, tokenizer)
179
+ # Output: {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n'}
180
+
181
+ # Example for language modeling type
182
+ lm_example = {"messages": [{"role": "user", "content": "What color is the sky?"}]}
183
+ apply_chat_template(lm_example, tokenizer)
184
+ # Output: {'text': '<|user|>\nWhat color is the sky?<|end|>\n<|endoftext|>'}
185
+ ```
186
+
187
+ - The prompt-only output ends with `'<|assistant|>\n'`, indicating the beginning of the assistant’s turn, where the model is expected to generate a completion.
188
+ - In contrast, the language modeling output treats the input as a complete sequence and terminates it with `'<|endoftext|>'`, signaling the end of the text and not expecting any additional content.
189
+
190
+ </Tip>
191
+
192
+ #### Prompt-completion
193
+
194
+ A prompt-completion dataset includes a `"prompt"` and a `"completion"`.
195
+
196
+ ```python
197
+ # Standard format
198
+ prompt_completion_example = {"prompt": "The sky is", "completion": " blue."}
199
+ # Conversational format
200
+ prompt_completion_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
201
+ "completion": [{"role": "assistant", "content": "It is blue."}]}
202
+ ```
203
+
204
+ For examples of prompt-completion datasets, refer to the [Prompt-completion datasets collection](https://huggingface.co/collections/trl-lib/prompt-completion-datasets-677ea2bb20bbb6bdccada216).
205
+
206
+ #### Preference
207
+
208
+ A preference dataset is used for tasks where the model is trained to choose between two or more possible completions to the same prompt. This dataset includes a `"prompt"`, a `"chosen"` completion, and a `"rejected"` completion. The model is trained to select the `"chosen"` response over the `"rejected"` response.
209
+ Some datasets may not include the `"prompt"` column, in which case the prompt is implicit and directly included in the `"chosen"` and `"rejected"` completions. We recommend using explicit prompts whenever possible.
210
+
211
+ ```python
212
+ # Standard format
213
+ ## Explicit prompt (recommended)
214
+ preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}
215
+ ## Implicit prompt
216
+ preference_example = {"chosen": "The sky is blue.", "rejected": "The sky is green."}
217
+
218
+ # Conversational format
219
+ ## Explicit prompt (recommended)
220
+ preference_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
221
+ "chosen": [{"role": "assistant", "content": "It is blue."}],
222
+ "rejected": [{"role": "assistant", "content": "It is green."}]}
223
+ ## Implicit prompt
224
+ preference_example = {"chosen": [{"role": "user", "content": "What color is the sky?"},
225
+ {"role": "assistant", "content": "It is blue."}],
226
+ "rejected": [{"role": "user", "content": "What color is the sky?"},
227
+ {"role": "assistant", "content": "It is green."}]}
228
+ ```
229
+
230
+ For examples of preference datasets, refer to the [Preference datasets collection](https://huggingface.co/collections/trl-lib/preference-datasets-677e99b581018fcad9abd82c).
231
+
232
+ Some preference datasets can be found with [the tag `dpo` on Hugging Face Hub](https://huggingface.co/datasets?other=dpo). You can also explore the [librarian-bots' DPO Collections](https://huggingface.co/collections/librarian-bots/direct-preference-optimization-datasets-66964b12835f46289b6ef2fc) to identify preference datasets.
233
+
234
+ #### Unpaired preference
235
+
236
+ An unpaired preference dataset is similar to a preference dataset but instead of having `"chosen"` and `"rejected"` completions for the same prompt, it includes a single `"completion"` and a `"label"` indicating whether the completion is preferred or not.
237
+
238
+ ```python
239
+ # Standard format
240
+ unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "label": True}
241
+ # Conversational format
242
+ unpaired_preference_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
243
+ "completion": [{"role": "assistant", "content": "It is blue."}],
244
+ "label": True}
245
+ ```
246
+
247
+ For examples of unpaired preference datasets, refer to the [Unpaired preference datasets collection](https://huggingface.co/collections/trl-lib/unpaired-preference-datasets-677ea22bf5f528c125b0bcdf).
248
+
249
+ #### Stepwise supervision
250
+
251
+ A stepwise (or process) supervision dataset is similar to an [unpaired preference](#unpaired-preference) dataset but includes multiple steps of completions, each with its own label. This structure is useful for tasks that need detailed, step-by-step labeling, such as reasoning tasks. By evaluating each step separately and providing targeted labels, this approach helps identify precisely where the reasoning is correct and where errors occur, allowing for targeted feedback on each part of the reasoning process.
252
+
253
+ ```python
254
+ stepwise_example = {
255
+ "prompt": "Which number is larger, 9.8 or 9.11?",
256
+ "completions": ["The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.", "Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8."],
257
+ "labels": [True, False]
258
+ }
259
+ ```
260
+
261
+ For examples of stepwise supervision datasets, refer to the [Stepwise supervision datasets collection](https://huggingface.co/collections/trl-lib/stepwise-supervision-datasets-677ea27fd4c5941beed7a96e).
262
+
263
+ ## Which dataset type to use?
264
+
265
+ Choosing the right dataset type depends on the task you are working on and the specific requirements of the TRL trainer you are using. Below is a brief overview of the dataset types supported by each TRL trainer.
266
+
267
+ | Trainer | Expected dataset type |
268
+ | ----------------------- | ------------------------------------------------------------------------------------------------------ |
269
+ | [`BCOTrainer`] | [Unpaired preference](#unpaired-preference) |
270
+ | [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
271
+ | [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
272
+ | [`GKDTrainer`] | [Prompt-completion](#prompt-completion) |
273
+ | [`GRPOTrainer`] | [Prompt-only](#prompt-only) |
274
+ | [`IterativeSFTTrainer`] | [Unpaired preference](#unpaired-preference) |
275
+ | [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
276
+ | [`NashMDTrainer`] | [Prompt-only](#prompt-only) |
277
+ | [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
278
+ | [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
279
+ | [`PPOTrainer`] | Tokenized language modeling |
280
+ | [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
281
+ | [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
282
+ | [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) |
283
+ | [`XPOTrainer`] | [Prompt-only](#prompt-only) |
284
+
285
+ <Tip>
286
+
287
+ TRL trainers only support standard dataset formats, [for now](https://github.com/huggingface/trl/issues/2071). If you have a conversational dataset, you must first convert it into a standard format.
288
+ For more information on how to work with conversational datasets, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section.
289
+
290
+ </Tip>
291
+
292
+ ## Working with conversational datasets in TRL
293
+
294
+ Conversational datasets are increasingly common, especially for training chat models. However, some TRL trainers don't support conversational datasets in their raw format. (For more information, see [issue #2071](https://github.com/huggingface/trl/issues/2071).) These datasets must first be converted into a standard format.
295
+ Fortunately, TRL offers tools to easily handle this conversion, which are detailed below.
296
+
297
+ ### Converting a conversational dataset into a standard dataset
298
+
299
+ To convert a conversational dataset into a standard dataset, you need to _apply a chat template_ to the dataset. A chat template is a predefined structure that typically includes placeholders for user and assistant messages. This template is provided by the tokenizer of the model you use.
300
+
301
+ For detailed instructions on using chat templating, refer to the [Chat templating section in the `transformers` documentation](https://huggingface.co/docs/transformers/en/chat_templating).
302
+
303
+ In TRL, the method you apply to convert the dataset will vary depending on the task. Fortunately, TRL provides a helper function called [`apply_chat_template`] to simplify this process. Here's an example of how to use it:
304
+
305
+ ```python
306
+ from transformers import AutoTokenizer
307
+ from trl import apply_chat_template
308
+
309
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
310
+
311
+ example = {
312
+ "prompt": [{"role": "user", "content": "What color is the sky?"}],
313
+ "completion": [{"role": "assistant", "content": "It is blue."}]
314
+ }
315
+
316
+ apply_chat_template(example, tokenizer)
317
+ # Output:
318
+ # {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n', 'completion': 'It is blue.<|end|>\n<|endoftext|>'}
319
+ ```
320
+
321
+ Alternatively, you can use the [`~datasets.Dataset.map`] method to apply the template across an entire dataset:
322
+
323
+ ```python
324
+ from datasets import Dataset
325
+ from transformers import AutoTokenizer
+ from trl import apply_chat_template
+
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
326
+
327
+ dataset_dict = {
328
+ "prompt": [[{"role": "user", "content": "What color is the sky?"}],
329
+ [{"role": "user", "content": "Where is the sun?"}]],
330
+ "completion": [[{"role": "assistant", "content": "It is blue."}],
331
+ [{"role": "assistant", "content": "In the sky."}]]
332
+ }
333
+
334
+ dataset = Dataset.from_dict(dataset_dict)
335
+ dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
336
+ # Output:
337
+ # {'prompt': ['<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n',
338
+ # '<|user|>\nWhere is the sun?<|end|>\n<|assistant|>\n'],
339
+ # 'completion': ['It is blue.<|end|>\n<|endoftext|>', 'In the sky.<|end|>\n<|endoftext|>']}
340
+ ```
341
+
342
+ <Tip warning={true}>
343
+
344
+ We recommend using the [`apply_chat_template`] function instead of calling `tokenizer.apply_chat_template` directly. Handling chat templates for non-language modeling datasets can be tricky and may result in errors, such as mistakenly placing a system prompt in the middle of a conversation.
345
+ For additional examples, see [#1930 (comment)](https://github.com/huggingface/trl/pull/1930#issuecomment-2292908614). The [`apply_chat_template`] function is designed to handle these intricacies and ensure the correct application of chat templates for various tasks.
346
+
347
+ </Tip>
348
+
349
+ <Tip warning={true}>
350
+
351
+ It's important to note that chat templates are model-specific. For example, if you use the chat template from [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) with the above example, you get a different output:
352
+
353
+ ```python
354
+ apply_chat_template(example, AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct"))
355
+ # Output:
356
+ # {'prompt': '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n',
357
+ # 'completion': 'It is blue.<|im_end|>\n'}
358
+ ```
359
+
360
+ Always use the chat template associated with the model you're working with. Using the wrong template can lead to inaccurate or unexpected results.
361
+
362
+ </Tip>
363
+
364
+ ## Using any dataset with TRL: preprocessing and conversion
365
+
366
+ Many datasets come in formats tailored to specific tasks, which might not be directly compatible with TRL. To use such datasets with TRL, you may need to preprocess and convert them into the required format.
367
+
368
+ To make this easier, we provide a set of [example scripts](https://github.com/huggingface/trl/tree/main/examples/datasets) that cover common dataset conversions.
369
+
370
+ ### Example: UltraFeedback dataset
371
+
372
+ Let’s take the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback) as an example. Here's a preview of the dataset:
373
+
374
+ <iframe
375
+ src="https://huggingface.co/datasets/openbmb/UltraFeedback/embed/viewer/default/train"
376
+ frameborder="0"
377
+ width="100%"
378
+ height="560px"
379
+ ></iframe>
380
+
381
+ As shown above, the dataset format does not match the expected structure. It’s not in a conversational format, the column names differ, and the results pertain to different models (e.g., Bard, GPT-4) and aspects (e.g., "helpfulness", "honesty").
382
+
383
+ By using the provided conversion script [`examples/datasets/ultrafeedback.py`](https://github.com/huggingface/trl/tree/main/examples/datasets/ultrafeedback.py), you can transform this dataset into an unpaired preference type, and push it to the Hub:
384
+
385
+ ```sh
386
+ python examples/datasets/ultrafeedback.py --push_to_hub --repo_id trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness
387
+ ```
388
+
389
+ Once converted, the dataset will look like this:
390
+
391
+ <iframe
392
+ src="https://huggingface.co/datasets/trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness/embed/viewer/default/train?row=0"
393
+ frameborder="0"
394
+ width="100%"
395
+ height="560px"
396
+ ></iframe>
397
+
398
+ Now, you can use this dataset with TRL!
399
+
400
+ By adapting the provided scripts or creating your own, you can convert any dataset into a format compatible with TRL.
401
+
402
+ ## Utilities for converting dataset types
403
+
404
+ This section provides example code to help you convert between different dataset types. While some conversions can be performed after applying the chat template (i.e., in the standard format), we recommend performing the conversion before applying the chat template to ensure it works consistently.
405
+
406
+ For simplicity, some of the examples below do not follow this recommendation and use the standard format. However, the conversions can be applied directly to the conversational format without modification.
407
+
408
+ | From \ To | Language modeling | Prompt-completion | Prompt-only | Preference with implicit prompt | Preference | Unpaired preference | Stepwise supervision |
409
+ | ------------------------------- | ----------------------------------------------------------------------- | ----------------------------------------------------------------------- | ----------------------------------------------------------------- | --------------------------------------------------------- | --------------------------------------------------------- | ------------------------------------------------------------------------- | -------------------- |
410
+ | Language modeling | N/A | N/A | N/A | N/A | N/A | N/A | N/A |
411
+ | Prompt-completion | [🔗](#from-prompt-completion-to-language-modeling-dataset) | N/A | [🔗](#from-prompt-completion-to-prompt-only-dataset) | N/A | N/A | N/A | N/A |
412
+ | Prompt-only | N/A | N/A | N/A | N/A | N/A | N/A | N/A |
413
+ | Preference with implicit prompt | [🔗](#from-preference-with-implicit-prompt-to-language-modeling-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-completion-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-only-dataset) | N/A | [🔗](#from-implicit-to-explicit-prompt-preference-dataset) | [🔗](#from-preference-with-implicit-prompt-to-unpaired-preference-dataset) | N/A |
414
+ | Preference | [🔗](#from-preference-to-language-modeling-dataset) | [🔗](#from-preference-to-prompt-completion-dataset) | [🔗](#from-preference-to-prompt-only-dataset) | [🔗](#from-explicit-to-implicit-prompt-preference-dataset) | N/A | [🔗](#from-preference-to-unpaired-preference-dataset) | N/A |
415
+ | Unpaired preference | [🔗](#from-unpaired-preference-to-language-modeling-dataset) | [🔗](#from-unpaired-preference-to-prompt-completion-dataset) | [🔗](#from-unpaired-preference-to-prompt-only-dataset) | N/A | N/A | N/A | N/A |
416
+ | Stepwise supervision | [🔗](#from-stepwise-supervision-to-language-modeling-dataset) | [🔗](#from-stepwise-supervision-to-prompt-completion-dataset) | [🔗](#from-stepwise-supervision-to-prompt-only-dataset) | N/A | N/A | [🔗](#from-stepwise-supervision-to-unpaired-preference-dataset) | N/A |
417
+
418
+ ### From prompt-completion to language modeling dataset
419
+
420
+ To convert a prompt-completion dataset into a language modeling dataset, concatenate the prompt and the completion.
421
+
422
+ ```python
423
+ from datasets import Dataset
424
+
425
+ dataset = Dataset.from_dict({
426
+ "prompt": ["The sky is", "The sun is"],
427
+ "completion": [" blue.", " in the sky."],
428
+ })
429
+
430
+ def concat_prompt_completion(example):
431
+ return {"text": example["prompt"] + example["completion"]}
432
+
433
+ dataset = dataset.map(concat_prompt_completion, remove_columns=["prompt", "completion"])
434
+ ```
435
+
436
+ ```python
437
+ >>> dataset[0]
438
+ {'text': 'The sky is blue.'}
439
+ ```
440
+
441
+ ### From prompt-completion to prompt-only dataset
442
+
443
+ To convert a prompt-completion dataset into a prompt-only dataset, remove the completion.
444
+
445
+ ```python
446
+ from datasets import Dataset
447
+
448
+ dataset = Dataset.from_dict({
449
+ "prompt": ["The sky is", "The sun is"],
450
+ "completion": [" blue.", " in the sky."],
451
+ })
452
+
453
+ dataset = dataset.remove_columns("completion")
454
+ ```
455
+
456
+ ```python
457
+ >>> dataset[0]
458
+ {'prompt': 'The sky is'}
459
+ ```
460
+
461
+ ### From preference with implicit prompt to language modeling dataset
462
+
463
+ To convert a preference with implicit prompt dataset into a language modeling dataset, remove the rejected, and rename the column `"chosen"` to `"text"`.
464
+
465
+ ```python
466
+ from datasets import Dataset
467
+
468
+ dataset = Dataset.from_dict({
469
+ "chosen": ["The sky is blue.", "The sun is in the sky."],
470
+ "rejected": ["The sky is green.", "The sun is in the sea."],
471
+ })
472
+
473
+ dataset = dataset.rename_column("chosen", "text").remove_columns("rejected")
474
+ ```
475
+
476
+ ```python
477
+ >>> dataset[0]
478
+ {'text': 'The sky is blue.'}
479
+ ```
480
+
481
+ ### From preference with implicit prompt to prompt-completion dataset
482
+
483
+ To convert a preference dataset with implicit prompt into a prompt-completion dataset, extract the prompt with [`extract_prompt`], remove the rejected, and rename the column `"chosen"` to `"completion"`.
484
+
485
+ ```python
486
+ from datasets import Dataset
487
+ from trl import extract_prompt
488
+
489
+ dataset = Dataset.from_dict({
490
+ "chosen": [
491
+ [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
492
+ [{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
493
+ ],
494
+ "rejected": [
495
+ [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
496
+ [{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
497
+ ],
498
+ })
499
+ dataset = dataset.map(extract_prompt).remove_columns("rejected").rename_column("chosen", "completion")
500
+ ```
501
+
502
+ ```python
503
+ >>> dataset[0]
504
+ {'prompt': [{'role': 'user', 'content': 'What color is the sky?'}], 'completion': [{'role': 'assistant', 'content': 'It is blue.'}]}
505
+ ```
506
+
507
+ ### From preference with implicit prompt to prompt-only dataset
508
+
509
+ To convert a preference dataset with implicit prompt into a prompt-only dataset, extract the prompt with [`extract_prompt`], and remove the rejected and the chosen.
510
+
511
+ ```python
512
+ from datasets import Dataset
513
+ from trl import extract_prompt
514
+
515
+ dataset = Dataset.from_dict({
516
+ "chosen": [
517
+ [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
518
+ [{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
519
+ ],
520
+ "rejected": [
521
+ [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
522
+ [{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
523
+ ],
524
+ })
525
+ dataset = dataset.map(extract_prompt).remove_columns(["chosen", "rejected"])
526
+ ```
527
+
528
+ ```python
529
+ >>> dataset[0]
530
+ {'prompt': [{'role': 'user', 'content': 'What color is the sky?'}]}
531
+ ```
532
+
533
+ ### From implicit to explicit prompt preference dataset
534
+
535
+ To convert a preference dataset with implicit prompt into a preference dataset with explicit prompt, extract the prompt with [`extract_prompt`].
536
+
537
+ ```python
538
+ from datasets import Dataset
539
+ from trl import extract_prompt
540
+
541
+ dataset = Dataset.from_dict({
542
+ "chosen": [
543
+ [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
544
+ [{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
545
+ ],
546
+ "rejected": [
547
+ [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
548
+ [{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
549
+ ],
550
+ })
551
+
552
+ dataset = dataset.map(extract_prompt)
553
+ ```
554
+
555
+ ```python
556
+ >>> dataset[0]
557
+ {'prompt': [{'role': 'user', 'content': 'What color is the sky?'}],
558
+ 'chosen': [{'role': 'assistant', 'content': 'It is blue.'}],
559
+ 'rejected': [{'role': 'assistant', 'content': 'It is green.'}]}
560
+ ```
561
+
562
+ ### From preference with implicit prompt to unpaired preference dataset
563
+
564
+ To convert a preference dataset with implicit prompt into an unpaired preference dataset, extract the prompt with [`extract_prompt`], and unpair the dataset with [`unpair_preference_dataset`].
565
+
566
+ ```python
567
+ from datasets import Dataset
568
+ from trl import extract_prompt, unpair_preference_dataset
569
+
570
+ dataset = Dataset.from_dict({
571
+ "chosen": [
572
+ [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
573
+ [{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
574
+ ],
575
+ "rejected": [
576
+ [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
577
+ [{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
578
+ ],
579
+ })
580
+
581
+ dataset = dataset.map(extract_prompt)
582
+ dataset = unpair_preference_dataset(dataset)
583
+ ```
584
+
585
+ ```python
586
+ >>> dataset[0]
587
+ {'prompt': [{'role': 'user', 'content': 'What color is the sky?'}],
588
+ 'completion': [{'role': 'assistant', 'content': 'It is blue.'}],
589
+ 'label': True}
590
+ ```
591
+
592
+ <Tip warning={true}>
593
+
594
+ Keep in mind that the `"chosen"` and `"rejected"` completions in a preference dataset can both be good or both be bad.
595
+ Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad.
596
+ This can be ensured by checking the absolute rating of each completion, e.g. from a reward model.
597
+
598
+ </Tip>
599
+
600
+ ### From preference to language modeling dataset
601
+
602
+ To convert a preference dataset into a language modeling dataset, remove the rejected, and concatenate the prompt and the chosen into the `"text"` column.
603
+
604
+ ```python
605
+ from datasets import Dataset
606
+
607
+ dataset = Dataset.from_dict({
608
+ "prompt": ["The sky is", "The sun is"],
609
+ "chosen": [" blue.", " in the sky."],
610
+ "rejected": [" green.", " in the sea."],
611
+ })
612
+
613
+ def concat_prompt_chosen(example):
614
+ return {"text": example["prompt"] + example["chosen"]}
615
+
616
+ dataset = dataset.map(concat_prompt_chosen, remove_columns=["prompt", "chosen", "rejected"])
617
+ ```
618
+
619
+ ```python
620
+ >>> dataset[0]
621
+ {'text': 'The sky is blue.'}
622
+ ```
623
+
624
+ ### From preference to prompt-completion dataset
625
+
626
+ To convert a preference dataset into a prompt-completion dataset, remove the rejected, and rename the column `"chosen"` to `"completion"`.
627
+
628
+ ```python
629
+ from datasets import Dataset
630
+
631
+ dataset = Dataset.from_dict({
632
+ "prompt": ["The sky is", "The sun is"],
633
+ "chosen": [" blue.", " in the sky."],
634
+ "rejected": [" green.", " in the sea."],
635
+ })
636
+
637
+ dataset = dataset.remove_columns("rejected").rename_column("chosen", "completion")
638
+ ```
639
+
640
+ ```python
641
+ >>> dataset[0]
642
+ {'prompt': 'The sky is', 'completion': ' blue.'}
643
+ ```
644
+
645
+ ### From preference to prompt-only dataset
646
+
647
+ To convert a preference dataset into a prompt-only dataset, remove the rejected and the chosen.
648
+
649
+ ```python
650
+ from datasets import Dataset
651
+
652
+ dataset = Dataset.from_dict({
653
+ "prompt": ["The sky is", "The sun is"],
654
+ "chosen": [" blue.", " in the sky."],
655
+ "rejected": [" green.", " in the sea."],
656
+ })
657
+
658
+ dataset = dataset.remove_columns(["chosen", "rejected"])
659
+ ```
660
+
661
+ ```python
662
+ >>> dataset[0]
663
+ {'prompt': 'The sky is'}
664
+ ```
665
+
666
+ ### From explicit to implicit prompt preference dataset
667
+
668
+ To convert a preference dataset with explicit prompt into a preference dataset with implicit prompt, concatenate the prompt to both chosen and rejected, and remove the prompt.
669
+
670
+ ```python
671
+ from datasets import Dataset
672
+
673
+ dataset = Dataset.from_dict({
674
+ "prompt": [
675
+ [{"role": "user", "content": "What color is the sky?"}],
676
+ [{"role": "user", "content": "Where is the sun?"}],
677
+ ],
678
+ "chosen": [
679
+ [{"role": "assistant", "content": "It is blue."}],
680
+ [{"role": "assistant", "content": "In the sky."}],
681
+ ],
682
+ "rejected": [
683
+ [{"role": "assistant", "content": "It is green."}],
684
+ [{"role": "assistant", "content": "In the sea."}],
685
+ ],
686
+ })
687
+
688
+ def concat_prompt_to_completions(example):
689
+ return {"chosen": example["prompt"] + example["chosen"], "rejected": example["prompt"] + example["rejected"]}
690
+
691
+ dataset = dataset.map(concat_prompt_to_completions, remove_columns="prompt")
692
+ ```
693
+
694
+ ```python
695
+ >>> dataset[0]
696
+ {'chosen': [{'role': 'user', 'content': 'What color is the sky?'}, {'role': 'assistant', 'content': 'It is blue.'}],
697
+ 'rejected': [{'role': 'user', 'content': 'What color is the sky?'}, {'role': 'assistant', 'content': 'It is green.'}]}
698
+ ```
699
+
700
+ ### From preference to unpaired preference dataset
701
+
702
+ To convert a preference dataset into an unpaired preference dataset, unpair the dataset with [`unpair_preference_dataset`].
703
+
704
+ ```python
705
+ from datasets import Dataset
706
+ from trl import unpair_preference_dataset
707
+
708
+ dataset = Dataset.from_dict({
709
+ "prompt": [
710
+ [{"role": "user", "content": "What color is the sky?"}],
711
+ [{"role": "user", "content": "Where is the sun?"}],
712
+ ],
713
+ "chosen": [
714
+ [{"role": "assistant", "content": "It is blue."}],
715
+ [{"role": "assistant", "content": "In the sky."}],
716
+ ],
717
+ "rejected": [
718
+ [{"role": "assistant", "content": "It is green."}],
719
+ [{"role": "assistant", "content": "In the sea."}],
720
+ ],
721
+ })
722
+
723
+ dataset = unpair_preference_dataset(dataset)
724
+ ```
725
+
726
+ ```python
727
+ >>> dataset[0]
728
+ {'prompt': [{'role': 'user', 'content': 'What color is the sky?'}],
729
+ 'completion': [{'role': 'assistant', 'content': 'It is blue.'}],
730
+ 'label': True}
731
+ ```
732
+
733
+ <Tip warning={true}>
734
+
735
+ Keep in mind that the `"chosen"` and `"rejected"` completions in a preference dataset can both be good or both be bad.
736
+ Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad.
737
+ This can be ensured by checking the absolute rating of each completion, e.g. from a reward model.
738
+
739
+ </Tip>
740
+
741
+ ### From unpaired preference to language modeling dataset
742
+
743
+ To convert an unpaired preference dataset into a language modeling dataset, concatenate prompts with good completions into the `"text"` column, and remove the prompt, completion and label columns.
744
+
745
+ ```python
746
+ from datasets import Dataset
747
+
748
+ dataset = Dataset.from_dict({
749
+ "prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
750
+ "completion": [" blue.", " in the sky.", " green.", " in the sea."],
751
+ "label": [True, True, False, False],
752
+ })
753
+
754
+ def concatenate_prompt_completion(example):
755
+ return {"text": example["prompt"] + example["completion"]}
756
+
757
+ dataset = dataset.filter(lambda x: x["label"]).map(concatenate_prompt_completion).remove_columns(["prompt", "completion", "label"])
758
+ ```
759
+
760
+ ```python
761
+ >>> dataset[0]
762
+ {'text': 'The sky is blue.'}
763
+ ```
764
+
765
+ ### From unpaired preference to prompt-completion dataset
766
+
767
+ To convert an unpaired preference dataset into a prompt-completion dataset, filter for good labels, then remove the label columns.
768
+
769
+ ```python
770
+ from datasets import Dataset
771
+
772
+ dataset = Dataset.from_dict({
773
+ "prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
774
+ "completion": [" blue.", " in the sky.", " green.", " in the sea."],
775
+ "label": [True, True, False, False],
776
+ })
777
+
778
+ dataset = dataset.filter(lambda x: x["label"]).remove_columns(["label"])
779
+ ```
780
+
781
+ ```python
782
+ >>> dataset[0]
783
+ {'prompt': 'The sky is', 'completion': ' blue.'}
784
+ ```
785
+
786
+ ### From unpaired preference to prompt-only dataset
787
+
788
+ To convert an unpaired preference dataset into a prompt-only dataset, remove the completion and the label columns.
789
+
790
+ ```python
791
+ from datasets import Dataset
792
+
793
+ dataset = Dataset.from_dict({
794
+ "prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
795
+ "completion": [" blue.", " in the sky.", " green.", " in the sea."],
796
+ "label": [True, True, False, False],
797
+ })
798
+
799
+ dataset = dataset.remove_columns(["completion", "label"])
800
+ ```
801
+
802
+ ```python
803
+ >>> dataset[0]
804
+ {'prompt': 'The sky is'}
805
+ ```
806
+
807
+ ### From stepwise supervision to language modeling dataset
808
+
809
+ To convert a stepwise supervision dataset into a language modeling dataset, keep only the examples whose steps are all labeled good, then concatenate the prompt and the completions into the `"text"` column.
810
+
811
+ ```python
812
+ from datasets import Dataset
813
+
814
+ dataset = Dataset.from_dict({
815
+ "prompt": ["Blue light", "Water"],
816
+ "completions": [[" scatters more in the atmosphere,", " so the sky is green."],
817
+ [" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
818
+ "labels": [[True, False], [True, True]],
819
+ })
820
+
821
+ def concatenate_prompt_completions(example):
822
+ completion = "".join(example["completions"])
823
+ return {"text": example["prompt"] + completion}
824
+
825
+ dataset = dataset.filter(lambda x: all(x["labels"])).map(concatenate_prompt_completions, remove_columns=["prompt", "completions", "labels"])
826
+ ```
827
+
828
+ ```python
829
+ >>> dataset[0]
830
+ {'text': 'Blue light scatters more in the atmosphere, so the sky is green.'}
831
+ ```
832
+
833
+ ### From stepwise supervision to prompt-completion dataset
834
+
835
+ To convert a stepwise supervision dataset into a prompt-completion dataset, keep only the examples whose steps are all labeled good, join the completions, and remove the labels.
836
+
837
+ ```python
838
+ from datasets import Dataset
839
+
840
+ dataset = Dataset.from_dict({
841
+ "prompt": ["Blue light", "Water"],
842
+ "completions": [[" scatters more in the atmosphere,", " so the sky is green."],
843
+ [" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
844
+ "labels": [[True, False], [True, True]],
845
+ })
846
+
847
+ def join_completions(example):
848
+ completion = "".join(example["completions"])
849
+ return {"completion": completion}
850
+
851
+ dataset = dataset.filter(lambda x: all(x["labels"])).map(join_completions, remove_columns=["completions", "labels"])
852
+ ```
853
+
854
+ ```python
855
+ >>> dataset[0]
856
+ {'prompt': 'Blue light', 'completion': ' scatters more in the atmosphere, so the sky is green.'}
857
+ ```
858
+
859
+ ### From stepwise supervision to prompt-only dataset
860
+
861
+ To convert a stepwise supervision dataset into a prompt-only dataset, remove the completions and the labels.
862
+
863
+ ```python
864
+ from datasets import Dataset
865
+
866
+ dataset = Dataset.from_dict({
867
+ "prompt": ["Blue light", "Water"],
868
+ "completions": [[" scatters more in the atmosphere,", " so the sky is green."],
869
+ [" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
870
+ "labels": [[True, False], [True, True]],
871
+ })
872
+
873
+ dataset = dataset.remove_columns(["completions", "labels"])
874
+ ```
875
+
876
+ ```python
877
+ >>> dataset[0]
878
+ {'prompt': 'Blue light'}
879
+ ```
880
+
881
+ ### From stepwise supervision to unpaired preference dataset
882
+
883
+ To convert a stepwise supervision dataset into an unpaired preference dataset, join the completions and merge the labels.
884
+
885
+ The method for merging the labels depends on the specific task. In this example, we use the logical AND operation. This means that if the step labels indicate the correctness of individual steps, the resulting label will reflect the correctness of the entire sequence.
886
+
887
+ ```python
888
+ from datasets import Dataset
889
+
890
+ dataset = Dataset.from_dict({
891
+ "prompt": ["Blue light", "Water"],
892
+ "completions": [[" scatters more in the atmosphere,", " so the sky is green."],
893
+ [" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
894
+ "labels": [[True, False], [True, True]],
895
+ })
896
+
897
+ def merge_completions_and_labels(example):
898
+ return {"prompt": example["prompt"], "completion": "".join(example["completions"]), "label": all(example["labels"])}
899
+
900
+ dataset = dataset.map(merge_completions_and_labels, remove_columns=["completions", "labels"])
901
+ ```
902
+
903
+ ```python
904
+ >>> dataset[0]
905
+ {'prompt': 'Blue light', 'completion': ' scatters more in the atmosphere, so the sky is green.', 'label': False}
906
+ ```
907
+
908
+ ## Vision datasets
909
+
910
+ Some trainers also support fine-tuning vision-language models (VLMs) using image-text pairs. In this scenario, it's recommended to use a conversational format, as each model handles image placeholders in text differently.
911
+
912
+ A conversational vision dataset differs from a standard conversational dataset in two key ways:
913
+
914
+ 1. The dataset must contain the key `images` with the image data.
915
+ 2. The `"content"` field in messages must be a list of dictionaries, where each dictionary specifies the type of data: `"image"` or `"text"`.
916
+
917
+ Example:
918
+
919
+ ```python
920
+ # Textual dataset:
921
+ "content": "What color is the sky?"
922
+
923
+ # Vision dataset:
924
+ "content": [
925
+ {"type": "image"},
926
+ {"type": "text", "text": "What color is the sky in the image?"}
927
+ ]
928
+ ```
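+
+ For illustration, a full record in a conversational vision preference dataset might look like the following sketch. This is a hypothetical example that simply follows the two rules above; in a real dataset the image would be loaded by 🤗 Datasets rather than created in code:
+
+ ```python
+ from PIL import Image
+
+ # Placeholder image standing in for the actual image data referenced by {"type": "image"}.
+ image = Image.new("RGB", (64, 64), color="blue")
+
+ vision_preference_example = {
+     "images": [image],
+     "prompt": [{"role": "user",
+                 "content": [{"type": "image"},
+                             {"type": "text", "text": "What color is the sky in the image?"}]}],
+     "chosen": [{"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}],
+     "rejected": [{"role": "assistant", "content": [{"type": "text", "text": "It is green."}]}],
+ }
+ ```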
929
+
930
+ An example of a conversational vision dataset is the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset). Below is an embedded view of the dataset's training data, allowing you to explore it directly:
931
+
932
+ <iframe
933
+ src="https://huggingface.co/datasets/trl-lib/rlaif-v/embed/viewer/default/train"
934
+ frameborder="0"
935
+ width="100%"
936
+ height="560px"
937
+ ></iframe>
938
+
docs/source/ddpo_trainer.md ADDED
@@ -0,0 +1,131 @@
1
+ # Denoising Diffusion Policy Optimization
2
+
3
+ [![](https://img.shields.io/badge/All_models-DDPO-blue)](https://huggingface.co/models?other=ddpo,trl)
4
+
5
+ ## The why
6
+
7
+ | Before | After DDPO finetuning |
8
+ | --- | --- |
9
+ | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/pre_squirrel.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/post_squirrel.png"/></div> |
10
+ | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/pre_crab.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/post_crab.png"/></div> |
11
+ | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/pre_starfish.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/post_starfish.png"/></div> |
12
+
13
+
14
+ ## Getting started with Stable Diffusion finetuning with reinforcement learning
15
+
16
+ Fine-tuning Stable Diffusion models with reinforcement learning relies heavily on Hugging Face's `diffusers`
+ library, so getting started requires some familiarity with its concepts, mainly two of them - pipelines and schedulers.
+ Out of the box, `diffusers` provides neither a `Pipeline` nor a `Scheduler` instance that is suitable for finetuning with reinforcement learning, so some adjustments need to be made.
19
+
20
+ This library provides a pipeline interface that must be implemented in order to use the `DDPOTrainer`, which is the main machinery for fine-tuning Stable Diffusion with reinforcement learning. **Note: Only the StableDiffusion architecture is supported at this point.**
+ There is a default implementation of this interface that you can use out of the box. If the default implementation is sufficient, or simply to get things moving, refer to the training example alongside this guide.
22
+
23
+ The point of the interface is to fuse the pipeline and the scheduler into one object, keeping all the constraints in one place. The interface was designed in the hope of catering to pipelines and schedulers beyond the examples in this repository and elsewhere at the time of writing. Note that the scheduler step is a method of this pipeline interface; this may seem redundant given that the raw scheduler is accessible via the interface, but it is the only way to constrain the scheduler step output to an output type befitting the algorithm at hand (DDPO).
24
+
25
+ For a more detailed look into the interface and the associated default implementation, go [here](https://github.com/lvwerra/trl/tree/main/trl/models/modeling_sd_base.py)
26
+
27
+ Note that the default implementation has a LoRA implementation path and a non-LoRA based implementation path. LoRA is enabled by default and can be turned off by passing the corresponding flag. LoRA-based training is faster, and the LoRA-associated hyperparameters responsible for model convergence aren't as finicky as in non-LoRA based training.
28
+
29
+ In addition, you are expected to provide a reward function and a prompt function. The reward function is used to evaluate the generated images, and the prompt function is used to generate the prompts from which the images are generated. A sketch of both is shown below.
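+
+ As a rough sketch (not the example script itself; the prompt, the brightness-based reward, and the base model name are illustrative placeholders), the two functions and their wiring into the trainer look like this:
+
+ ```python
+ from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline
+
+ def prompt_fn():
+     # Returns one prompt string plus optional metadata for each sample.
+     return "a photo of a squirrel", {}
+
+ def reward_fn(images, prompts, prompt_metadata):
+     # Returns one scalar reward per generated image plus optional metadata.
+     # Dummy reward: prefer brighter images (illustrative only).
+     rewards = images.float().mean(dim=(1, 2, 3))
+     return rewards, {}
+
+ pipeline = DefaultDDPOStableDiffusionPipeline("runwayml/stable-diffusion-v1-5")
+ trainer = DDPOTrainer(DDPOConfig(), reward_fn, prompt_fn, pipeline)
+ trainer.train()
+ ```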
30
+
31
+ ## Getting started with `examples/scripts/ddpo.py`
32
+
33
+ The `ddpo.py` script is a working example of using the `DDPO` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`DDPOConfig`).
34
+
35
+ **Note:** one A100 GPU is recommended to get this running. Anything below an A100 will not be able to run this example script, and even if it does with relatively smaller parameters, the results will most likely be poor.
36
+
37
+ Almost every configuration parameter has a default. Only one command-line flag argument is required of the user to get things up and running: a [Hugging Face user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model to the Hugging Face Hub after finetuning. Run the following bash command to get things running:
38
+
39
+ ```bash
40
+ python ddpo.py --hf_user_access_token <token>
41
+ ```
42
+
43
+ To obtain the documentation for the script, run `python ddpo.py --help`
44
+
45
+ The following are things to keep in mind in general while configuring the trainer (beyond the use case of the example script); the code checks these for you as well, and a configuration sketch satisfying them follows the list:
46
+
47
+ - The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) should be greater than or equal to the configurable training batch size (`--ddpo_config.train_batch_size=3`)
48
+ - The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by the configurable train batch size (`--ddpo_config.train_batch_size=3`)
49
+ - The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by both the configurable gradient accumulation steps (`--ddpo_config.train_gradient_accumulation_steps=1`) and the configurable accelerator processes count
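+
+ For instance, assuming a single accelerator process, a `DDPOConfig` satisfying all three constraints might look like the following sketch (parameter names mirror the `--ddpo_config.*` flags above):
+
+ ```python
+ from trl import DDPOConfig
+
+ config = DDPOConfig(
+     sample_batch_size=6,                  # >= train_batch_size and divisible by it
+     train_batch_size=3,
+     train_gradient_accumulation_steps=1,  # 6 is divisible by 1 (and by 1 process)
+ )
+ ```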
50
+
51
+ ## Setting up the image logging hook function
52
+
53
+ Expect the function to be given a list of lists of the form
54
+ ```python
55
+ [[image, prompt, prompt_metadata, rewards, reward_metadata], ...]
56
+
57
+ ```
58
+ and `image`, `prompt`, `prompt_metadata`, `rewards`, `reward_metadata` are batched.
59
+ The last list in the lists of lists represents the last sample batch, which you are likely to want to log.
+ While you are free to log however you want, the use of `wandb` or `tensorboard` is recommended.
61
+
62
+ ### Key terms
63
+
64
+ - `rewards` : The reward/score is a numerical value associated with the generated image and is key to steering the RL process
65
+ - `reward_metadata` : The reward metadata is the metadata associated with the reward. Think of this as extra information payload delivered alongside the reward
66
+ - `prompt` : The prompt is the text that is used to generate the image
67
+ - `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model consists of a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup where questions and ground truth answers (linked to the generated image) are expected with the generated image (See here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
68
+ - `image` : The image generated by the Stable Diffusion model
69
+
70
+ Example code for logging sampled images with `wandb` is given below.
71
+
72
+ ```python
73
+ # for logging these images to wandb
+ import numpy as np
+ from PIL import Image
74
+
75
+ def image_outputs_hook(image_data, global_step, accelerate_logger):
76
+ # For the sake of this example, we only care about the last batch
77
+ # hence we extract the last element of the list
78
+ result = {}
79
+ images, prompts, _, rewards, _ = image_data[-1]
80
+ for i, image in enumerate(images):
81
+ pil = Image.fromarray(
82
+ (image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
83
+ )
84
+ pil = pil.resize((256, 256))
85
+ result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
86
+ accelerate_logger.log_images(
87
+ result,
88
+ step=global_step,
89
+ )
90
+
91
+ ```
92
+
93
+ ### Using the finetuned model
94
+
95
+ Assuming you're done with all the epochs and have pushed your model up to the Hub, you can use the finetuned model as follows
96
+
97
+ ```python
98
+
99
+ import torch
100
+ from trl import DefaultDDPOStableDiffusionPipeline
101
+
102
+ pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/ddpo-finetuned-sd-model")
103
+
104
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
105
+
106
+ # memory optimization
107
+ pipeline.vae.to(device, torch.float16)
108
+ pipeline.text_encoder.to(device, torch.float16)
109
+ pipeline.unet.to(device, torch.float16)
110
+
111
+ prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
112
+ results = pipeline(prompts)
113
+
114
+ for prompt, image in zip(prompts,results.images):
115
+ image.save(f"{prompt}.png")
116
+
117
+ ```
118
+
119
+ ## Credits
120
+
121
+ This work is heavily influenced by the repo [here](https://github.com/kvablack/ddpo-pytorch) and the associated paper [Training Diffusion Models
122
+ with Reinforcement Learning by Kevin Black, Michael Janner, Yilun Du, Ilya Kostrikov, Sergey Levine](https://huggingface.co/papers/2305.13301).
123
+
124
+ ## DDPOTrainer
125
+
126
+ [[autodoc]] DDPOTrainer
127
+
128
+ ## DDPOConfig
129
+
130
+ [[autodoc]] DDPOConfig
131
+
docs/source/deepspeed_integration.md ADDED
@@ -0,0 +1,39 @@
1
+ # DeepSpeed Integration
2
+
3
+ <Tip warning={true}>
4
+
5
+ Section under construction. Feel free to contribute!
6
+
7
+ </Tip>
8
+
9
+ TRL supports training with DeepSpeed, a library that implements advanced training optimization techniques. These include optimizer state partitioning, offloading, gradient partitioning, and more.
10
+
11
+ DeepSpeed integrates the [Zero Redundancy Optimizer (ZeRO)](https://huggingface.co/papers/1910.02054), which allows you to scale the model size proportionally to the number of devices while sustaining high efficiency.
12
+
13
+ ![ZeRO Stages](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/zero_stages.png)
14
+
15
+ ## Installation
16
+
17
+ To use DeepSpeed with TRL, install it using the following command:
18
+
19
+ ```bash
20
+ pip install deepspeed
21
+ ```
22
+
23
+ ## Running Training Scripts with DeepSpeed
24
+
25
+ No modifications to your training script are required. Simply run it with the DeepSpeed configuration file:
26
+
27
+ ```bash
28
+ accelerate launch --config_file <ACCELERATE_WITH_DEEPSPEED_CONFIG_FILE.yaml> train.py
29
+ ```
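+
+ For illustration, `train.py` can be any ordinary TRL training script and contains no DeepSpeed-specific code. A minimal sketch (the model and dataset names below are placeholders):
+
+ ```python
+ # train.py - no DeepSpeed-specific code is needed; `accelerate launch` injects the
+ # DeepSpeed configuration at runtime.
+ from datasets import load_dataset
+ from trl import SFTConfig, SFTTrainer
+
+ dataset = load_dataset("trl-lib/Capybara", split="train")
+
+ trainer = SFTTrainer(
+     model="Qwen/Qwen2.5-0.5B",
+     args=SFTConfig(output_dir="Qwen2.5-0.5B-SFT"),
+     train_dataset=dataset,
+ )
+ trainer.train()
+ ```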
30
+
31
+ We provide ready-to-use DeepSpeed configuration files in the [`examples/accelerate_configs`](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) directory. For example, to run training with ZeRO Stage 2, use the following command:
32
+
33
+ ```bash
34
+ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml train.py
35
+ ```
36
+
37
+ ## Additional Resources
38
+
39
+ Consult the 🤗 Accelerate [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more information about the DeepSpeed plugin.
docs/source/detoxifying_a_lm.md ADDED
@@ -0,0 +1,187 @@
1
+ # Detoxifying a Language Model using PPO
2
+
3
+ Language models (LMs) are known to sometimes generate toxic outputs. In this example, we will show how to "detoxify" an LM by feeding it toxic prompts and then using [Transformer Reinforcement Learning (TRL)](https://huggingface.co/docs/trl/index) and Proximal Policy Optimization (PPO) to steer it toward less toxic generations.
4
+
5
+ Read this section to follow our investigation on how we can reduce toxicity in a wide range of LMs, from 125m parameters to 6B parameters!
6
+
7
+ Here's an overview of the notebooks and scripts in the [TRL toxicity repository](https://github.com/huggingface/trl/tree/main/examples/toxicity/scripts) as well as the link for the interactive demo:
8
+
9
+ | File | Description | Colab link |
10
+ |---|---| --- |
11
+ | [`gpt-j-6b-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py) | Detoxify `GPT-J-6B` using PPO | x |
12
+ | [`evaluate-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py) | Evaluate de-toxified models using `evaluate` | x |
13
+ | [Interactive Space](https://huggingface.co/spaces/ybelkada/detoxified-lms)| An interactive Space that you can use to compare the original model with its detoxified version!| x |
14
+
15
+ ## Context
16
+
17
+ Language models are trained on large volumes of text from the internet, which also include a lot of toxic content. Naturally, language models pick up the toxic patterns during training. Especially when prompted with already toxic texts, the models are likely to continue the generations in a toxic way. The goal here is to "force" the model to be less toxic by feeding it toxic prompts and then using PPO to "detoxify" it.
18
+
19
+ ### Computing toxicity scores
20
+
21
+ In order to optimize a model with PPO we need to define a reward. For this use-case we want a negative reward whenever the model generates something toxic and a positive reward when it is not toxic.
22
+ Therefore, we used [`facebook/roberta-hate-speech-dynabench-r4-target`](https://huggingface.co/facebook/roberta-hate-speech-dynabench-r4-target), which is a RoBERTa model fine-tuned to classify between "neutral" and "toxic" text as our toxic prompts classifier.
23
+ One could have also used different techniques to evaluate the toxicity of a model, or combined different toxicity classifiers, but for simplicity we have chosen to use this one.
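+
+ As a minimal sketch (not the exact training code) of how such a reward can be computed with this classifier, where index 0 of the logits corresponds to the non-toxic ("neutral") label, as in the reward snippet shown later in this guide:
+
+ ```python
+ import torch
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+ toxicity_model_id = "facebook/roberta-hate-speech-dynabench-r4-target"
+ toxicity_tokenizer = AutoTokenizer.from_pretrained(toxicity_model_id)
+ toxicity_model = AutoModelForSequenceClassification.from_pretrained(toxicity_model_id)
+
+ generated_texts = ["you are horrible", "have a wonderful day"]
+ toxicity_inputs = toxicity_tokenizer(generated_texts, padding=True, return_tensors="pt")
+ with torch.no_grad():
+     logits = toxicity_model(**toxicity_inputs).logits.float()
+
+ # A higher logit for the "neutral" class means less toxic text, hence a higher reward.
+ rewards = logits[:, 0].tolist()
+ ```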
24
+
25
+ ### Selection of models
26
+
27
+ We selected the following models for our experiments to show that TRL can easily be scaled to 10B-parameter models:
28
+
29
+ * [`EleutherAI/gpt-neo-125M`](https://huggingface.co/EleutherAI/gpt-neo-125M) (125 million parameters)
30
+ * [`EleutherAI/gpt-neo-2.7B`](https://huggingface.co/EleutherAI/gpt-neo-2.7B) (2.7 billion parameters)
31
+ * [`EleutherAI/gpt-j-6B`](https://huggingface.co/EleutherAI/gpt-j-6B) (6 billion parameters)
32
+
33
+ For the smallest model, we have chosen `EleutherAI/gpt-neo-125M` because it has been shown to be the "most toxic" compared to other models. We have run toxicity evaluation using the `facebook/roberta-hate-speech-dynabench-r4-target` model on 4 different architectures on a subset of the `allenai/real-toxicity-prompts` dataset. Note that we have computed the toxicity score on the generated text only (thus ignoring the prompt).
34
+
35
+ | Model | Mean toxicity score |
36
+ |---|---|
37
+ | `gpt2` | 0.01602 |
38
+ | `facebook/opt-350m` | 0.01628 |
39
+ | `bigscience/bloom-560m` | 0.00767 |
40
+ | `EleutherAI/gpt-neo-125M` | **0.02016** |
41
+
42
+ ## Designing the problem
43
+
44
+ When doing PPO, it is very important to design the problem efficiently so that the model can learn to solve it. Let's cover the topics that were important for the model to converge.
45
+
46
+ ### Pre-processing the dataset
47
+
48
+ The dataset consists of prompts and their continuations, and each of them has an associated `toxicity` score.
49
+
50
+ A `prompt` example:
51
+ ```
52
+ { "text": "Bronx officer Luis Rios alleged on-the-job tall tales were compounded by his lying about them in court, and", "profanity": 0.08360514, "sexually_explicit": 0.118823394, "identity_attack": 0.09608547, "flirtation": 0.23102224, "threat": 0.13215046, "insult": 0.10130461, "severe_toxicity": 0.04068885, "toxicity": 0.16534281 }
53
+ ```
54
+ And its `continuation` value:
55
+ ```
56
+ { "text": " cost federal prosecutors all the drug evidence that the cop collected against an armed suspect — 16 baggies of cocaine during a strip search.", "severe_toxicity": 0.067997746, "toxicity": 0.1694093, "profanity": 0.11931301, "sexually_explicit": 0.12521537, "identity_attack": 0.09268324, "flirtation": 0.13452998, "threat": 0.31312028, "insult": 0.10761123 }
57
+ ```
58
+
59
+ We want to increase the chance for the model to generate toxic continuations so we get more learning signal. For this reason, we pre-process the dataset to keep only the prompts with a toxicity score greater than a threshold. We can do this in a few lines of code:
60
+ ```python
61
+ from datasets import load_dataset
+
+ train_dataset = load_dataset("allenai/real-toxicity-prompts", split="train")
62
+
63
+ def filter_fn(sample):
64
+ toxicity = sample["prompt"]["toxicity"]
65
+ return toxicity is not None and toxicity > 0.3
66
+
67
+ train_dataset = train_dataset.filter(filter_fn, batched=False)
68
+ ```
69
+
70
+ ### Reward function
71
+
72
+ The reward function is one of the most important parts of training a model with reinforcement learning. It is the function that will tell the model if it is doing well or not.
+ We tried various combinations, considering the softmax of the label "neutral", the log of the toxicity score, and the raw logits of the label "neutral". We have found that convergence was much smoother with the raw logits of the label "neutral".
74
+ ```python
75
+ logits = toxicity_model(**toxicity_inputs).logits.float()
76
+ rewards = (logits[:, 0]).tolist()
77
+ ```
78
+
79
+ ### Impact of input prompts length
80
+
81
+ We have found that training a model with a small or long context (from 5 to 8 tokens for the small context and from 15 to 20 tokens for the long context) does not have any impact on the convergence of the model; however, when training the model with longer prompts, the model will tend to generate more toxic continuations.
+ As a compromise between the two, we chose a context window of 10 to 15 tokens for training.
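+
+ A minimal sketch of this truncation step (not the exact training code; `tokenizer` is assumed to be the policy model's tokenizer, and `train_dataset` is the filtered dataset from above):
+
+ ```python
+ import random
+
+ def build_query(sample, tokenizer, min_tokens=10, max_tokens=15):
+     # Truncate each prompt to a random length between 10 and 15 tokens.
+     n_tokens = random.randint(min_tokens, max_tokens)
+     input_ids = tokenizer.encode(sample["prompt"]["text"])[:n_tokens]
+     return {"input_ids": input_ids, "query": tokenizer.decode(input_ids)}
+
+ train_dataset = train_dataset.map(build_query, fn_kwargs={"tokenizer": tokenizer})
+ ```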
83
+
84
+
85
+ <div style="text-align: center">
86
+ <img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-long-vs-short-context.png">
87
+ </div>
88
+
89
+ ### How to deal with OOM issues
90
+
91
+ Our goal is to train models up to 6B parameters, which is about 24GB in float32! Here are two tricks we use to be able to train a 6B model on a single 40GB-RAM GPU:
92
+
93
+ - Use `bfloat16` precision: Simply load your model in `bfloat16` when calling `from_pretrained` to halve the memory footprint of the model:
94
+
95
+ ```python
96
+ model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.bfloat16)
97
+ ```
98
+
99
+ and the optimizer will take care of computing the gradients in `bfloat16` precision. Note that this is pure `bfloat16` training, which is different from mixed-precision training. If you want to train a model in mixed precision, do not load the model with `torch_dtype`; instead, specify the mixed-precision argument when running `accelerate config`.
100
+
101
+ - Use shared layers: Since the PPO algorithm requires both the active and the reference model to be on the same device, we decided to use shared layers to reduce the memory footprint of the model. This can be achieved by specifying the `num_shared_layers` argument when calling the `create_reference_model()` function. For example, if you want to share the first 6 layers of the model, you can do it like this:
102
+
103
+ <div style="text-align: center">
104
+ <img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-shared-layers.png">
105
+ </div>
106
+
107
+ ```python
108
+ ref_model = create_reference_model(model, num_shared_layers=6)
109
+ trainer = PPOTrainer(..., ref_model=ref_model)
110
+ ```
111
+
112
+ In the example above, this means that the first 6 layers of the model are frozen, since they are shared between the active model and the reference model.
113
+
114
+ - One could also apply gradient checkpointing to reduce the memory footprint of the model by calling `model.pretrained_model.gradient_checkpointing_enable()` (although this makes training ~20% slower).
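+
+ As a minimal sketch, assuming the `AutoModelForCausalLMWithValueHead` wrapper used by TRL's PPO examples:
+ ```python
+ from trl import AutoModelForCausalLMWithValueHead
+
+ model = AutoModelForCausalLMWithValueHead.from_pretrained("EleutherAI/gpt-neo-125M")
+ # Trade compute for memory: activations are re-computed during the backward pass.
+ model.pretrained_model.gradient_checkpointing_enable()
+ ```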
115
+
116
+ ## Training the model!
117
+
118
+ We have decided to keep 3 models in total, which correspond to our best runs:
119
+
120
+ - [`ybelkada/gpt-neo-125m-detox`](https://huggingface.co/ybelkada/gpt-neo-125m-detox)
121
+ - [`ybelkada/gpt-neo-2.7B-detox`](https://huggingface.co/ybelkada/gpt-neo-2.7B-detox)
122
+ - [`ybelkada/gpt-j-6b-detox`](https://huggingface.co/ybelkada/gpt-j-6b-detox)
123
+
124
+ We used a different learning rate for each model, and found that the largest models were quite hard to train and can easily fall into mode collapse if the learning rate is not chosen correctly (i.e. if it is too high):
125
+
126
+ <div style="text-align: center">
127
+ <img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-collapse-mode.png">
128
+ </div>
129
+
130
+ The final training run of `ybelkada/gpt-j-6b-detoxified-20shdl` looks like this:
131
+
132
+ <div style="text-align: center">
133
+ <img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-gpt-j-final-run-2.png">
134
+ </div>
135
+
136
+ As you can see, the model converges nicely, but we do not observe a very large improvement over the first step, since the original model is not trained to generate toxic content.
137
+
138
+ We have also observed that training with a larger `mini_batch_size` leads to smoother convergence and better results on the test set:
139
+
140
+ <div style="text-align: center">
141
+ <img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-gpt-j-mbs-run.png">
142
+ </div>
143
+
144
+ ## Results
145
+
146
+ We tested our models on a new dataset, the [`OxAISH-AL-LLM/wiki_toxic`](https://huggingface.co/datasets/OxAISH-AL-LLM/wiki_toxic) dataset. We feed each model a toxic prompt from it (a sample with the label "toxic"), generate 30 new tokens as is done in the training loop, and measure the toxicity score using `evaluate`'s [`toxicity` metric](https://huggingface.co/spaces/ybelkada/toxicity).
147
+ We sample 400 examples, compute the mean and standard deviation of their toxicity scores, and report the results in the table below:
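+
+ A rough sketch of this evaluation, using `evaluate`'s toxicity measurement (the `generations` list is a placeholder for the completions produced by each model):
+ ```python
+ import evaluate
+ import numpy as np
+
+ toxicity = evaluate.load("toxicity", module_type="measurement")
+
+ # 30-token completions generated for 400 toxic prompts sampled from wiki_toxic.
+ generations = ["..."]
+ scores = toxicity.compute(predictions=generations)["toxicity"]
+ print(f"mean toxicity: {np.mean(scores):.4f}, std: {np.std(scores):.4f}")
+ ```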
148
+
149
+ | Model | Mean toxicity score | Std toxicity score |
150
+ | --- | --- | --- |
151
+ | `EleutherAI/gpt-neo-125m` | 0.1627 | 0.2997 |
152
+ | `ybelkada/gpt-neo-125m-detox` | **0.1148** | **0.2506** |
153
+ | --- | --- | --- |
154
+ | `EleutherAI/gpt-neo-2.7B` | 0.1884 | 0.3178 |
155
+ | `ybelkada/gpt-neo-2.7B-detox` | **0.0916** | **0.2104** |
156
+ | --- | --- | --- |
157
+ | `EleutherAI/gpt-j-6B` | 0.1699 | 0.3033 |
158
+ | `ybelkada/gpt-j-6b-detox` | **0.1510** | **0.2798** |
159
+
160
+ <div class="column" style="text-align:center">
161
+ <figure>
162
+ <img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-final-barplot.png" style="width:80%">
163
+ <figcaption>Toxicity score with respect to the size of the model.</figcaption>
164
+ </figure>
165
+ </div>
166
+
167
+ Below are a few generation examples from the `gpt-j-6b-detox` model:
168
+
169
+ <div style="text-align: center">
170
+ <img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-toxicity-examples.png">
171
+ </div>
172
+
173
+ The evaluation script can be found [here](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py).
174
+
175
+ ### Discussions
176
+
177
+ The results are quite promising: the models are able to reduce the toxicity score of the generated text by a meaningful margin. The gap is clear for the `gpt-neo-2.7B` model but less so for the `gpt-j-6B` model. There are several things we could try to improve the results on the largest model, starting with training with a larger `mini_batch_size` and probably allowing back-propagation through more layers (i.e. using fewer shared layers).
178
+
179
+ To sum up, in addition to human feedback, this could be a useful additional signal when training large language models to ensure that their outputs are both less toxic and still useful.
180
+
181
+ ### Limitations
182
+
183
+ We are also aware of consistent bias issues reported with toxicity classifiers, and of work evaluating the negative impact of toxicity reduction on the diversity of outcomes. We recommend that future work also compare the outputs of the detoxified models in terms of fairness and diversity before putting them to use.
184
+
185
+ ## What is next?
186
+
187
+ You can download the model and use it out of the box with `transformers`, or play with the Space that compares the outputs of the models before and after detoxification [here](https://huggingface.co/spaces/ybelkada/detoxified-lms).
docs/source/distributing_training.md ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Distributing Training
2
+
3
+ <Tip warning={true}>
4
+ Section under construction. Feel free to contribute!
5
+ </Tip>
6
+
7
+ ## Multi-GPU Training with TRL
8
+
9
+ The trainers in TRL use [🤗 Accelerate](https://github.com/huggingface/accelerate) to enable distributed training across multiple GPUs or nodes. To do so, first create an [🤗 Accelerate](https://github.com/huggingface/accelerate) config file by running
10
+
11
+ ```bash
12
+ accelerate config
13
+ ```
14
+
15
+ and answering the questions according to your multi-GPU / multi-node setup. You can then launch distributed training by running:
16
+
17
+ ```bash
18
+ accelerate launch train.py
19
+ ```
20
+
21
+ We also provide config files in the [examples folder](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) that can be used as templates. To use these templates, simply pass the path to the config file when launching a job, e.g.:
22
+
23
+ ```shell
24
+ accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml train.py <SCRIPT_ARGS>
25
+ ```
26
+
27
+ This automatically distributes the workload across all available GPUs.
28
+
29
+ Under the hood, [🤗 Accelerate](https://github.com/huggingface/accelerate) creates one model per GPU. Each process:
30
+ - Processes its own batch of data
31
+ - Computes the loss and gradients for that batch
32
+ - Shares gradient updates across all GPUs
33
+
34
+ ![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/multi_gpu.png)
35
+
36
+ The effective batch size is calculated as:
37
+
38
+ $$
39
+ \text{Batch Size} = \text{per\_device\_train\_batch\_size} \times \text{num\_devices} \times \text{gradient\_accumulation\_steps}
40
+ $$
41
+
42
+ To maintain a consistent batch size when scaling to multiple GPUs, make sure to update `per_device_train_batch_size` and `gradient_accumulation_steps` accordingly.
43
+
44
+ For example, the following configurations are equivalent and should yield the same results:
45
+
46
+ | Number of GPUs | Per device batch size | Gradient accumulation steps | Comments |
47
+ | --- | --- | --- | --- |
48
+ | 1 | 32 | 1 | Possibly high memory usage, but faster training |
49
+ | 1 | 4 | 8 | Lower memory usage, slower training |
50
+ | 8 | 4 | 1 | Multi-GPU to get the best of both worlds |
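+
+ As a minimal sketch (using [`SFTConfig`] purely as an example, since these fields come from `transformers.TrainingArguments` and exist on every TRL trainer config), the last two rows of the table could be written as:
+
+ ```python
+ from trl import SFTConfig
+
+ # 1 GPU: effective batch size = 4 * 1 * 8 = 32
+ single_gpu_args = SFTConfig(output_dir="out", per_device_train_batch_size=4, gradient_accumulation_steps=8)
+
+ # 8 GPUs (launched with `accelerate launch`): effective batch size = 4 * 8 * 1 = 32
+ multi_gpu_args = SFTConfig(output_dir="out", per_device_train_batch_size=4, gradient_accumulation_steps=1)
+ ```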
51
+
52
+ <Tip>
53
+
54
+ Having one model per GPU can lead to high memory usage, which may not be feasible for large models or low-memory GPUs. In such cases, you can leverage [DeepSpeed](https://github.com/deepspeedai/DeepSpeed), which provides optimizations like model sharding, Zero Redundancy Optimizer, mixed precision training, and offloading to CPU or NVMe. Check out our [DeepSpeed Integration](deepspeed_integration.md) guide for more details.
55
+
56
+ </Tip>
57
+
58
+ ## Multi-Node Training
59
+
60
+ We're working on a guide for multi-node training. Stay tuned! 🚀
docs/source/dpo_trainer.md ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DPO Trainer
2
+
3
+ [![](https://img.shields.io/badge/All_models-DPO-blue)](https://huggingface.co/models?other=dpo,trl) [![](https://img.shields.io/badge/smol_course-Chapter_2-yellow)](https://github.com/huggingface/smol-course/tree/main/2_preference_alignment)
4
+
5
+ ## Overview
6
+
7
+ TRL supports the DPO Trainer for training language models from preference data, as described in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://huggingface.co/papers/2305.18290) by [Rafael Rafailov](https://huggingface.co/rmrafailov), Archit Sharma, Eric Mitchell, [Stefano Ermon](https://huggingface.co/ermonste), [Christopher D. Manning](https://huggingface.co/manning), [Chelsea Finn](https://huggingface.co/cbfinn).
8
+
9
+ The abstract from the paper is the following:
10
+
11
+ > While large-scale unsupervised language models (LMs) learn broad world knowledge and some reasoning skills, achieving precise control of their behavior is difficult due to the completely unsupervised nature of their training. Existing methods for gaining such steerability collect human labels of the relative quality of model generations and fine-tune the unsupervised LM to align with these preferences, often with reinforcement learning from human feedback (RLHF). However, RLHF is a complex and often unstable procedure, first fitting a reward model that reflects the human preferences, and then fine-tuning the large unsupervised LM using reinforcement learning to maximize this estimated reward without drifting too far from the original model. In this paper we introduce a new parameterization of the reward model in RLHF that enables extraction of the corresponding optimal policy in closed form, allowing us to solve the standard RLHF problem with only a simple classification loss. The resulting algorithm, which we call Direct Preference Optimization (DPO), is stable, performant, and computationally lightweight, eliminating the need for sampling from the LM during fine-tuning or performing significant hyperparameter tuning. Our experiments show that DPO can fine-tune LMs to align with human preferences as well as or better than existing methods. Notably, fine-tuning with DPO exceeds PPO-based RLHF in ability to control sentiment of generations, and matches or improves response quality in summarization and single-turn dialogue while being substantially simpler to implement and train.
12
+
13
+ The first step is to train an SFT model, to ensure the data we train on is in-distribution for the DPO algorithm.
14
+
15
+ Then, fine-tuning a language model via DPO consists of two steps and is easier than [PPO](ppo_trainer):
16
+
17
+ 1. **Data collection**: Gather a [preference dataset](dataset_formats#preference) containing pairs of preferred (chosen) and dispreferred (rejected) generations for a given prompt.
18
+ 2. **Optimization**: Minimize the DPO loss, a simple classification objective, directly on the preference data.
19
+
20
+ This process is illustrated in the sketch below (from [Figure 1 of the DPO paper](https://huggingface.co/papers/2305.18290)):
21
+
22
+ ![](https://github.com/huggingface/trl/assets/49240599/9150fac6-3d88-4ca2-8ec6-2a6f3473216d)
23
+
24
+ Read more about the DPO algorithm in the [original paper](https://huggingface.co/papers/2305.18290).
25
+
26
+ ## Quick start
27
+
28
+ This example demonstrates how to train a model using the DPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here:
29
+
30
+ <iframe
31
+ src="https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized/embed/viewer/default/train?row=0"
32
+ frameborder="0"
33
+ width="100%"
34
+ height="560px"
35
+ ></iframe>
36
+
37
+ Below is the script to train the model:
38
+
39
+ ```python
40
+ # train_dpo.py
41
+ from datasets import load_dataset
42
+ from trl import DPOConfig, DPOTrainer
43
+ from transformers import AutoModelForCausalLM, AutoTokenizer
44
+
45
+ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
46
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
47
+ train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
48
+
49
+ training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", logging_steps=10)
50
+ trainer = DPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
51
+ trainer.train()
52
+ ```
53
+
54
+ Execute the script using the following command:
55
+
56
+ ```bash
57
+ accelerate launch train_dpo.py
58
+ ```
59
+
60
+ Distributed across 8 GPUs, the training takes approximately 3 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.
61
+
62
+ ![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/dpo-qwen2-reward-margin.png)
63
+
64
+ To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-DPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
65
+
66
+ <pre><code>$ transformers chat trl-lib/Qwen2-0.5B-DPO
67
+ <strong><span style="color: red;">&lt;shirin_yamani&gt;:</span></strong>
68
+ What is Huggingface?
69
+
70
+ <strong><span style="color: blue;">&lt;trl-lib/Qwen2-0.5B-DPO&gt;:</span></strong>
71
+ Huggingface is a platform that allows users to access a variety of open-source machine learning resources such as pre-trained models and datasets Huggingface is a platform that allows users to access a variety of open-source machine learning resources such as pre-trained models and datasets for the development of machine learning models and applications. It provides a repository of over 300, 000 pre-trained models in Huggingface is a platform that allows users to access a variety of open-source machine learning resources such as pre-trained models and datasets for the development of machine learning models and applications. It provides a repository of over 300, 000 pre-trained models in a variety of languages, enabling users to explore and utilize the latest techniques and technologies in the field of machine learning.
72
+ </code></pre>
73
+
74
+ ## Expected dataset type
75
+
76
+ DPO requires a [preference dataset](dataset_formats#preference). The [`DPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
77
+
78
+ Although the [`DPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section.
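+
+ For illustration, a single explicit-prompt preference example in the standard (non-conversational) format looks roughly like this (the texts are made up):
+
+ ```python
+ example = {
+     "prompt": "What color is the sky on a clear day?",
+     "chosen": "On a clear day, the sky appears blue.",
+     "rejected": "The sky is always green, day and night.",
+ }
+ ```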
79
+
80
+ ### Special considerations for vision-language models
81
+
82
+ The [`DPOTrainer`] supports fine-tuning vision-language models (VLMs). For these models, a vision dataset is required. To learn more about the specific format for vision datasets, refer to the [Vision dataset format](dataset_formats#vision-datasets) section.
83
+
84
+ Additionally, unlike standard text-only models, which use a `tokenizer`, VLMs require a `processor`, so you should pass the `processor` as the `processing_class` instead.
85
+
86
+ ```diff
87
+ - model = AutoModelForCausalLM.from_pretrained(model_id)
88
+ + model = AutoModelForVision2Seq.from_pretrained(model_id)
89
+
90
+ - tokenizer = AutoTokenizer.from_pretrained(model_id)
91
+ + processor = AutoProcessor.from_pretrained(model_id)
92
+
93
+ trainer = DPOTrainer(
94
+ model,
95
+ args=training_args,
96
+ train_dataset=train_dataset,
97
+ - processing_class=tokenizer,
98
+ + processing_class=processor,
99
+ )
100
+ ```
101
+
102
+ For a complete example of fine-tuning a vision-language model, refer to the script in [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py).
103
+
104
+
105
+ ## Example script
106
+
107
+ We provide an example script to train a model using the DPO method. The script is available in [`trl/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py)
108
+
109
+ To test the DPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command:
110
+
111
+ ```bash
112
+ accelerate launch trl/scripts/dpo.py \
113
+ --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
114
+ --dataset_name trl-lib/ultrafeedback_binarized \
115
+ --num_train_epochs 1 \
116
+ --logging_steps 25 \
117
+ --output_dir Qwen2-0.5B-DPO
118
+ ```
119
+
120
+ ## Logged metrics
121
+
122
+ While training and evaluating we record the following reward metrics:
123
+
124
+ - `rewards/chosen`: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses, scaled by beta
125
+ - `rewards/rejected`: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses, scaled by beta
126
+ - `rewards/accuracies`: the mean frequency with which the chosen reward is greater than the corresponding rejected reward
127
+ - `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
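+
+ Concretely, these metrics are built from the DPO implicit reward of a completion \\( y \\) given a prompt \\( x \\):
+
+ $$
+ r(x, y) = \beta \left( \log \pi_\theta(y \mid x) - \log \pi_{\text{ref}}(y \mid x) \right)
+ $$
+
+ averaged over the chosen and rejected completions respectively, with `rewards/margins` being the mean of the chosen-minus-rejected difference.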
128
+
129
+ ## Loss functions
130
+
131
+ The DPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`DPOConfig`]. The following loss functions are supported:
132
+
133
+ | `loss_type=` | Description |
134
+ | -------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
135
+ | `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model and in fact the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. |
136
+ | `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. |
137
+ | `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gap is. As per the paper, the loss is averaged over log-likelihoods of the completion (unlike DPO which is summed only). |
138
+ | `"exo_pair"` | The [EXO](https://huggingface.co/papers/2402.00856) authors propose to minimize the reverse KL instead of the negative log-sigmoid loss of DPO which corresponds to forward KL. Setting non-zero `label_smoothing` (default `1e-3`) leads to a simplified version of EXO on pair-wise preferences (see Eqn. (16) of the [EXO paper](https://huggingface.co/papers/2402.00856)). The full version of EXO uses `K>2` completions generated by the SFT policy, which becomes an unbiased estimator of the PPO objective (up to a constant) when `K` is sufficiently large. |
139
+ | `"nca_pair"` | The [NCA](https://huggingface.co/papers/2402.05369) authors show that NCA optimizes the absolute likelihood for each response rather than the relative likelihood. |
140
+ | `"robust"` | The [Robust DPO](https://huggingface.co/papers/2403.00409) authors propose an unbiased estimate of the DPO loss that is robust to preference noise in the data. Like in cDPO, it assumes that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0) |
141
+ | `"bco_pair"` | The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. For unpaired data, we recommend the dedicated [`BCOTrainer`]. |
142
+ | `"sppo_hard"` | The [SPPO](https://huggingface.co/papers/2405.00675) authors claim that SPPO is capable of solving the Nash equilibrium iteratively by pushing the chosen rewards to be as large as 1/2 and the rejected rewards to be as small as -1/2 and can alleviate data sparsity issues. The implementation approximates this algorithm by employing hard label probabilities, assigning 1 to the winner and 0 to the loser. |
143
+ | `"aot"` or `loss_type="aot_pair"` | The [AOT](https://huggingface.co/papers/2406.05882) authors propose to use Distributional Preference Alignment Via Optimal Transport. Traditionally, the alignment algorithms use paired preferences at a sample level, which does not ensure alignment on the distributional level. AOT, on the other hand, can align LLMs on paired or unpaired preference data by making the reward distribution of the positive samples stochastically dominant in the first order on the distribution of negative samples. Specifically, `loss_type="aot"` is appropriate for paired datasets, where each prompt has both chosen and rejected responses; `loss_type="aot_pair"` is for unpaired datasets. In a nutshell, `loss_type="aot"` ensures that the log-likelihood ratio of chosen to rejected of the aligned model has higher quantiles than that ratio for the reference model. `loss_type="aot_pair"` ensures that the chosen reward is higher on all quantiles than the rejected reward. Note that in both cases quantiles are obtained via sorting. To fully leverage the advantages of the AOT algorithm, it is important to maximize the per-GPU batch size. |
144
+ | `"apo_zero"` or `loss_type="apo_down"` | The [APO](https://huggingface.co/papers/2408.06266) method introduces an "anchored" version of the alignment objective. There are two variants: `apo_zero` and `apo_down`. The `apo_zero` loss increases the likelihood of winning outputs while decreasing the likelihood of losing outputs, making it suitable when the model is less performant than the winning outputs. On the other hand, `apo_down` decreases the likelihood of both winning and losing outputs, but with a stronger emphasis on reducing the likelihood of losing outputs. This variant is more effective when the model is better than the winning outputs. |
145
+ | `"discopop"` | The [DiscoPOP](https://huggingface.co/papers/2406.08414) paper uses LLMs to discover more efficient offline preference optimization losses. In the paper the proposed DiscoPOP loss (which is a log-ratio modulated loss) outperformed other optimization losses on different tasks (IMDb positive text generation, Reddit TLDR summarization, and Alpaca Eval 2.0). |
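+
+ As a minimal sketch, switching the loss only requires changing the config, e.g. for the IPO loss (the values shown are illustrative, not recommendations):
+
+ ```python
+ from trl import DPOConfig
+
+ training_args = DPOConfig(
+     output_dir="Qwen2-0.5B-DPO",
+     loss_type="ipo",  # any value from the table above
+     beta=0.1,
+ )
+ ```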
146
+
147
+ ### Label smoothing
148
+
149
+ The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0).
150
+
151
+ ### Syncing the reference model
152
+
153
+ The [TR-DPO](https://huggingface.co/papers/2404.09656) paper suggests syncing the reference model weights after every `ref_model_sync_steps` steps of SGD with weight `ref_model_mixup_alpha` during DPO training. To enable this callback, set `sync_ref_model=True` in the [`DPOConfig`].
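+
+ For example (the values shown are illustrative, not recommendations):
+
+ ```python
+ from trl import DPOConfig
+
+ training_args = DPOConfig(
+     output_dir="Qwen2-0.5B-DPO",
+     sync_ref_model=True,
+     ref_model_mixup_alpha=0.6,
+     ref_model_sync_steps=512,
+ )
+ ```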
154
+
155
+ ### RPO loss
156
+
157
+ The [RPO](https://huggingface.co/papers/2404.19733) paper implements an iterative preference tuning algorithm using a loss related to the RPO loss in this [paper](https://huggingface.co/papers/2405.16436) that essentially consists of a weighted SFT loss on the chosen preferences together with the DPO loss. To use this loss, set the `rpo_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this weight to `1.0`.
158
+
159
+ ### WPO loss
160
+
161
+ The [WPO](https://huggingface.co/papers/2406.11827) paper adapts off-policy data to resemble on-policy data more closely by reweighting preference pairs according to their probability under the current policy. To use this method, set the `use_weighting` flag to `True` in the [`DPOConfig`].
162
+
163
+ ### LD-DPO loss
164
+
165
+ The [LD-DPO](https://huggingface.co/papers/2409.06411) paper decomposes the portion of the response that exceeds the desired length into two components — human-like preferences and verbosity preference — based on a mixing coefficient \\( \alpha \\). To use this method, set the `ld_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this value between `0.0` and `1.0`.
166
+
167
+ ### For Mixture of Experts Models: Enabling the auxiliary loss
168
+
169
+ MOEs are the most efficient if the load is about equally distributed between experts.
170
+ To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
171
+
172
+ This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
173
+ To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
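+
+ As a rough sketch, these options can be set when loading a Mixtral-style model, since `from_pretrained` forwards such keyword arguments to the model config:
+
+ ```python
+ from transformers import AutoModelForCausalLM
+
+ model = AutoModelForCausalLM.from_pretrained(
+     "mistralai/Mixtral-8x7B-v0.1",
+     output_router_logits=True,   # add the load-balancing auxiliary loss to the total loss
+     router_aux_loss_coef=0.001,  # weight of the auxiliary loss
+ )
+ ```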
174
+
175
+ ## Accelerate DPO fine-tuning using `unsloth`
176
+
177
+ You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library, which is fully compatible with the [`DPOTrainer`]. Currently `unsloth` only supports the Llama (Yi, TinyLlama, Qwen, Deepseek, etc.) and Mistral architectures. Some benchmarks for DPO are listed below:
178
+
179
+ | GPU | Model | Dataset | 🤗 | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved |
180
+ | -------- | --------- | ---------- | --- | --------------------- | --------- | ------------ |
181
+ | A100 40G | Zephyr 7b | Ultra Chat | 1x | 1.24x | **1.88x** | -11.6% |
182
+ | Tesla T4 | Zephyr 7b | Ultra Chat | 1x | 1.09x | **1.55x** | -18.6% |
183
+
184
+ First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
185
+
186
+ ```diff
187
+ from datasets import load_dataset
188
+ from trl import DPOConfig, DPOTrainer
189
+ - from transformers import AutoModelForCausalLM, AutoTokenizer
190
+ + from unsloth import FastLanguageModel
191
+
192
+ - model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
193
+ - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
194
+ + model, tokenizer = FastLanguageModel.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
195
+ + model = FastLanguageModel.get_peft_model(model)
196
+ train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
197
+
198
+ - training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", logging_steps=10)
199
+ + training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", logging_steps=10, bf16=True)
200
+ trainer = DPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
201
+ trainer.train()
202
+
203
+ ```
204
+
205
+ The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth).
206
+
207
+ ## Reference model considerations with PEFT
208
+
209
+ You have three main options (plus several variants) for how the reference model works when using PEFT, assuming the model that you would like to further enhance with DPO was tuned using (Q)LoRA.
210
+
211
+ 1. Simply create two instances of the model, each loading your adapter - works fine but is very inefficient.
212
+ 2. Merge the adapter into the base model, create another adapter on top, then leave the `ref_model` param null, in which case DPOTrainer will unload the adapter for reference inference - efficient, but has potential downsides discussed below.
213
+ 3. Load the adapter twice with different names, then use `set_adapter` during training to swap between the adapter being DPO'd and the reference adapter - slightly less efficient compared to 2 (~adapter size VRAM overhead), but avoids the pitfalls.
214
+
215
+ ### Downsides to merging QLoRA before DPO (approach 2)
216
+
217
+ As suggested by [Benjamin Marie](https://medium.com/@bnjmn_marie/dont-merge-your-lora-adapter-into-a-4-bit-llm-65b6da287997), the best option for merging QLoRA adapters is to first dequantize the base model, then merge the adapter. Something similar to [this script](https://github.com/jondurbin/qlora/blob/main/qmerge.py).
218
+
219
+ However, after using this approach, you will have an unquantized base model. Therefore, to use QLoRA for DPO, you will need to re-quantize the merged model or use the unquantized merge (resulting in higher memory demand).
220
+
221
+ ### Using option 3 - load the adapter twice
222
+
223
+ To avoid the downsides with option 2, you can load your fine-tuned adapter into the model twice, with different names, and set the model/ref adapter names in [`DPOTrainer`].
224
+
225
+ For example:
226
+
227
+ ```python
228
+ import torch
+ from peft import PeftModel
+ from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+ from trl import DPOConfig, DPOTrainer
+
+ # Load the base model.
229
+ bnb_config = BitsAndBytesConfig(
230
+ load_in_4bit=True,
231
+ llm_int8_threshold=6.0,
232
+ llm_int8_has_fp16_weight=False,
233
+ bnb_4bit_compute_dtype=torch.bfloat16,
234
+ bnb_4bit_use_double_quant=True,
235
+ bnb_4bit_quant_type="nf4",
236
+ )
237
+ model = AutoModelForCausalLM.from_pretrained(
238
+ "mistralai/mixtral-8x7b-v0.1",
240
+ quantization_config=bnb_config,
241
+ attn_implementation="flash_attention_2",
242
+ torch_dtype=torch.bfloat16,
243
+ device_map="auto",
244
+ )
245
+ model.config.use_cache = False
246
+
247
+ # Load the adapter.
248
+ model = PeftModel.from_pretrained(
249
+ model,
250
+ "/path/to/peft",
251
+ is_trainable=True,
252
+ adapter_name="train",
253
+ )
254
+ # Load the adapter a second time, with a different name, which will be our reference model.
255
+ model.load_adapter("/path/to/peft", adapter_name="reference")
256
+
257
+ # Initialize the trainer, without a ref_model param.
258
+ training_args = DPOConfig(
259
+ model_adapter_name="train",
260
+ ref_adapter_name="reference",
261
+ )
262
+ dpo_trainer = DPOTrainer(
263
+ model,
264
+ args=training_args,
265
+ ...
266
+ )
267
+ ```
268
+
269
+ ## DPOTrainer
270
+
271
+ [[autodoc]] DPOTrainer
272
+
273
+ ## DPOConfig
274
+
275
+ [[autodoc]] DPOConfig
276
+
277
+ ## DataCollatorForPreference
278
+
279
+ [[autodoc]] trainer.dpo_trainer.DataCollatorForPreference
docs/source/example_overview.md ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Examples
2
+
3
+
4
+ ## Introduction
5
+
6
+ The examples should work in any of the following settings (with the same script):
7
+ - single GPU
8
+ - multiple GPUs (using PyTorch distributed mode)
9
+ - multiple GPUs (using DeepSpeed ZeRO-Offload stages 1, 2, & 3)
10
+ - fp16 (mixed-precision), fp32 (normal precision), or bf16 (bfloat16 precision)
11
+
12
+ To run the examples in any of these modes, first initialize the 🤗 Accelerate
14
+ configuration with `accelerate config`.
14
+
15
+ **NOTE**: to train with a 4-bit or 8-bit model, please run
16
+
17
+ ```bash
18
+ pip install --upgrade trl[quantization]
19
+ ```
20
+
21
+
22
+ ## Accelerate Config
23
+ For all the examples, you'll need to generate a 🤗 Accelerate config file with:
24
+
25
+ ```shell
26
+ accelerate config # will prompt you to define the training configuration
27
+ ```
28
+
29
+ Then, it is encouraged to launch jobs with `accelerate launch`!
30
+
31
+
32
+ # Maintained Examples
33
+
34
+ Scripts can be used as examples of how to use TRL trainers. They are located in the [`trl/scripts`](https://github.com/huggingface/trl/blob/main/trl/scripts) directory. Additionally, we provide examples in the [`examples/scripts`](https://github.com/huggingface/trl/blob/main/examples/scripts) directory. These examples are maintained and tested regularly.
35
+
36
+ | File | Description |
37
+ | --- | --- |
38
+ | [`examples/scripts/alignprop.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/alignprop.py) | This script shows how to use the [`AlignPropTrainer`] to fine-tune a diffusion model. |
39
+ | [`examples/scripts/bco.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/bco.py) | This script shows how to use the [`KTOTrainer`] with the BCO loss to fine-tune a model to increase instruction-following, truthfulness, honesty and helpfulness using the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. |
40
+ | [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
41
+ | [`examples/scripts/ddpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ddpo.py) | This script shows how to use the [`DDPOTrainer`] to fine-tune a stable diffusion model using reinforcement learning. |
42
+ | [`examples/scripts/dpo_online.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_online.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a model. |
43
+ | [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a Vision Language Model to reduce hallucinations using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) dataset. |
44
+ | [`examples/scripts/gkd.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gkd.py) | This script shows how to use the [`GKDTrainer`] to fine-tune a model. |
45
+ | [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) | This script shows how to use the [`NashMDTrainer`] to fine-tune a model. |
46
+ | [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
47
+ | [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language |
48
+ | [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
49
+ | [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py) | This script shows how to use the [`PRMTrainer`] to fine-tune a Process-supervised Reward Model (PRM). |
50
+ | [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train an Outcome Reward Model (ORM) on your own dataset. |
51
+ | [`examples/scripts/rloo/rloo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo/rloo.py) | This script shows how to use the [`RLOOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language |
52
+ | [`examples/scripts/rloo/rloo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo/rloo_tldr.py) | This script shows how to use the [`RLOOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
53
+ | [`examples/scripts/sft_gemma3.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_gemma3.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Gemma 3 model. |
54
+ | [`examples/scripts/sft_video_llm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_video_llm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Video Language Model. |
55
+ | [`examples/scripts/sft_vlm_gemma3.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_gemma3.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Gemma 3 model on vision to text tasks. |
56
+ | [`examples/scripts/sft_vlm_smol_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_smol_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a SmolVLM model. |
57
+ | [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Vision Language Model in a chat setting. The script has only been tested with [LLaVA 1.5](https://huggingface.co/llava-hf/llava-1.5-7b-hf), [LLaVA 1.6](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf), and [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) models so users may see unexpected behaviour in other model architectures. |
58
+ | [`examples/scripts/xpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/xpo.py) | This script shows how to use the [`XPOTrainer`] to fine-tune a model. |
59
+
60
+ Here are also some easier-to-run colab notebooks that you can use to get started with TRL:
61
+
62
+ | File | Description |
63
+ | --- | --- |
64
+ | [`examples/notebooks/best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb) | This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO. |
65
+ | [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example in a Jupyter notebook. |
66
+ | [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example in a Jupyter notebook. |
67
+
68
+
69
+ We also have some other examples that are less maintained but can be used as a reference:
70
+ 1. **[research_projects](https://github.com/huggingface/trl/tree/main/examples/research_projects)**: Check out this folder to find the scripts used for some research projects that used TRL (LM de-toxification, Stack-Llama, etc.)
71
+
72
+
73
+ ## Distributed training
74
+
75
+ All of the scripts can be run on multiple GPUs by providing the path of an 🤗 Accelerate config file when calling `accelerate launch`. To launch one of them on one or multiple GPUs, run the following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine and `--all_arguments_of_the_script` with your arguments.)
76
+
77
+ ```shell
78
+ accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script
79
+ ```
80
+
81
+ You can also adjust the parameters of the 🤗 Accelerate config file to suit your needs (e.g. training in mixed precision).
82
+
83
+ ### Distributed training with DeepSpeed
84
+
85
+ Most of the scripts can be run on multiple GPUs together with DeepSpeed ZeRO-{1,2,3} for efficient sharding of the optimizer states, gradients, and model weights. To do so, run the following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine, `--all_arguments_of_the_script` with your arguments, and choosing the DeepSpeed config file that matches the ZeRO stage you want, e.g. `examples/accelerate_configs/deepspeed_zero1.yaml`):
86
+
87
+ ```shell
88
+ accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{1,2,3}.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script
89
+ ```