diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..20d7456666646824a9159821f01d9d4133c7c3fa --- /dev/null +++ b/.gitattributes @@ -0,0 +1,47 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +*.otf filter=lfs diff=lfs merge=lfs -text +*.eot filter=lfs diff=lfs merge=lfs -text +*.ttf filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text +**/*.otf filter=lfs diff=lfs merge=lfs -text +**/*.eot filter=lfs diff=lfs merge=lfs -text +**/*.ttf filter=lfs diff=lfs merge=lfs -text +**/*.png filter=lfs diff=lfs merge=lfs -text +docs/**/*.otf filter=lfs diff=lfs merge=lfs -text +docs/**/*.eot filter=lfs diff=lfs merge=lfs -text +docs/**/*.ttf filter=lfs diff=lfs merge=lfs -text +docs/**/*.png filter=lfs diff=lfs merge=lfs -text diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml new file mode 100644 index 0000000000000000000000000000000000000000..fbf352cffccef11f4690e7b41346504f16ce8778 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -0,0 +1,67 @@ +name: "\U0001F41B Bug Report" +description: Submit a bug report to help us improve TRL +labels: [ "bug" ] +body: + - type: markdown + attributes: + value: | + Thanks for taking the time to fill out this bug report! 🤗 + + 🚩 If it is your first time submitting, be sure to check our [bug report guidelines](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#did-you-find-a-bug) + + - type: textarea + id: reproduction + validations: + required: true + attributes: + label: Reproduction + description: | + Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet. + If you have code snippets, error messages, stack traces please provide them here as well. + Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting + Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code. + + value: | + ```python + from trl import ... + + ``` + + outputs: + + ``` + Traceback (most recent call last): + File "example.py", line 42, in + ... + ``` + + - type: textarea + id: system-info + attributes: + label: System Info + description: | + Please provide information about your system: platform, Python version, PyTorch version, Transformers version, devices, TRL version, ... + You can get this information by running `trl env` in your terminal. + + placeholder: Copy-paste the output of `trl env` + validations: + required: true + + - type: checkboxes + id: terms + attributes: + label: Checklist + description: | + Before submitting, please confirm that you've completed each of the following. + If an item doesn't apply to your issue, check it anyway to show you've reviewed it. + options: + - label: "I have checked that my issue isn't already filed (see [open issues](https://github.com/huggingface/trl/issues?q=is%3Aissue))" + required: true + - label: "I have included my system information" + required: true + - label: "Any code provided is minimal, complete, and reproducible ([more on MREs](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))" + required: true + - label: "Any code provided is properly formatted in code blocks, (no screenshot, [more on code blocks](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))" + required: true + - label: "Any traceback provided is complete" + required: true diff --git a/.github/ISSUE_TEMPLATE/feature-request.yml b/.github/ISSUE_TEMPLATE/feature-request.yml new file mode 100644 index 0000000000000000000000000000000000000000..0a593186c098ae3824ef994374686092f97ccb4a --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature-request.yml @@ -0,0 +1,31 @@ +name: "\U0001F680 Feature request" +description: Submit a proposal/request for a new TRL feature +labels: [ "Feature request" ] +body: + - type: textarea + id: feature-request + validations: + required: true + attributes: + label: Feature request + description: | + A clear and concise description of the feature proposal. Please provide a link to the paper and code in case they exist. + + - type: textarea + id: motivation + validations: + required: true + attributes: + label: Motivation + description: | + Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too. + + + - type: textarea + id: contribution + validations: + required: true + attributes: + label: Your contribution + description: | + Is there any way that you could help, e.g. by submitting a PR? Make sure to read the CONTRIBUTING.MD [readme](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) diff --git a/.github/ISSUE_TEMPLATE/new-trainer-addition.yml b/.github/ISSUE_TEMPLATE/new-trainer-addition.yml new file mode 100644 index 0000000000000000000000000000000000000000..ea0b5afb10ae6d7519d07ee510faf617f369048c --- /dev/null +++ b/.github/ISSUE_TEMPLATE/new-trainer-addition.yml @@ -0,0 +1,32 @@ +name: "\U0001F31F New trainer addition" +description: Submit a proposal/request to implement a new trainer for a post-training method +labels: [ "New trainer" ] + +body: + - type: textarea + id: description-request + validations: + required: true + attributes: + label: Method description + description: | + Put any and all important information relative to the method + + - type: checkboxes + id: information-tasks + attributes: + label: Open source status + description: | + Please note that if the method implementation isn't available or model weights with training datasets aren't available, we are less likely to implement it in `trl`. + options: + - label: "The method implementation is available" + - label: "The model weights are available" + - label: "The training datasets are available" + + - type: textarea + id: additional-info + attributes: + label: Provide useful links for the implementation + description: | + Please provide information regarding the implementation, the weights, and the authors. + Please mention the authors by @gh-username if you're aware of their usernames. diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000000000000000000000000000000000000..4768848cbeb58b3d4ee90be332e86729fb094b27 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,31 @@ +# What does this PR do? + + + + + +Fixes # (issue) + + +## Before submitting +- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). +- [ ] Did you read the [contributor guideline](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#create-a-pull-request), + Pull Request section? +- [ ] Was this discussed/approved via a GitHub issue? Please add a link + to it if that's the case. +- [ ] Did you make sure to update the documentation with your changes? +- [ ] Did you write any new necessary tests? + + +## Who can review? + +Anyone in the community is free to review the PR once the tests have passed. Feel free to tag +members/contributors who may be interested in your PR. \ No newline at end of file diff --git a/.github/codeql/custom-queries.qls b/.github/codeql/custom-queries.qls new file mode 100644 index 0000000000000000000000000000000000000000..81deab4a871ed3b8114eeec45a4e2edbf9204b70 --- /dev/null +++ b/.github/codeql/custom-queries.qls @@ -0,0 +1,19 @@ +import codeql + +from WorkflowString interpolation, Workflow workflow +where + interpolation.getStringValue().matches("${{ github.event.issue.title }}") or + interpolation.getStringValue().matches("${{ github.event.issue.body }}") or + interpolation.getStringValue().matches("${{ github.event.pull_request.title }}") or + interpolation.getStringValue().matches("${{ github.event.pull_request.body }}") or + interpolation.getStringValue().matches("${{ github.event.review.body }}") or + interpolation.getStringValue().matches("${{ github.event.comment.body }}") or + interpolation.getStringValue().matches("${{ github.event.inputs.* }}") or + interpolation.getStringValue().matches("${{ github.event.head_commit.message }}") + interpolation.getStringValue().matches("${{ github.event.* }}") and + ( + step.getKey() = "run" or // Injection in run + step.getKey() = "env" or // Injection via env + step.getKey() = "with" // Injection via with + ) +select workflow, "🚨 Do not use directly as input of action" diff --git a/.github/workflows/build_documentation.yml b/.github/workflows/build_documentation.yml new file mode 100644 index 0000000000000000000000000000000000000000..d66349f6c85d01a16070c56d848ae4afb0f66cf6 --- /dev/null +++ b/.github/workflows/build_documentation.yml @@ -0,0 +1,19 @@ +name: Build documentation + +on: + push: + branches: + - main + - doc-builder* + - v*-release + +jobs: + build: + uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main + with: + commit_sha: ${{ github.sha }} + package: trl + version_tag_suffix: "" + custom_container: huggingface/transformers-doc-builder + secrets: + hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml new file mode 100644 index 0000000000000000000000000000000000000000..53134b6850015b6cbcb83cbbe49b6710f470883a --- /dev/null +++ b/.github/workflows/build_pr_documentation.yml @@ -0,0 +1,19 @@ +name: Build PR Documentation + +on: + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + build: + if: github.event.pull_request.draft == false + uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main + with: + commit_sha: ${{ github.event.pull_request.head.sha }} + pr_number: ${{ github.event.number }} + package: trl + version_tag_suffix: "" + custom_container: huggingface/transformers-doc-builder diff --git a/.github/workflows/clear_cache.yml b/.github/workflows/clear_cache.yml new file mode 100644 index 0000000000000000000000000000000000000000..b4f6681905893770be8b5f81ac4485ee4b90da86 --- /dev/null +++ b/.github/workflows/clear_cache.yml @@ -0,0 +1,33 @@ +name: "Cleanup Cache" + +on: + workflow_dispatch: + schedule: + - cron: "0 0 * * *" + +jobs: + cleanup: + runs-on: ubuntu-latest + steps: + - name: Check out code + uses: actions/checkout@v4 + + - name: Cleanup + run: | + gh extension install actions/gh-actions-cache + + REPO=${{ github.repository }} + + echo "Fetching list of cache key" + cacheKeysForPR=$(gh actions-cache list -R $REPO | cut -f 1 ) + + ## Setting this to not fail the workflow while deleting cache keys. + set +e + echo "Deleting caches..." + for cacheKey in $cacheKeysForPR + do + gh actions-cache delete $cacheKey -R $REPO --confirm + done + echo "Done" + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/codeQL.yml b/.github/workflows/codeQL.yml new file mode 100644 index 0000000000000000000000000000000000000000..d8114daca54ab7eb83a416636e8b5b31a3989694 --- /dev/null +++ b/.github/workflows/codeQL.yml @@ -0,0 +1,26 @@ +name: "CodeQL Analysis - Workflows" + +on: + workflow_dispatch: + +jobs: + analyze: + name: "Analyze GitHub Workflows" + runs-on: ubuntu-latest + permissions: + security-events: write + actions: read + contents: read + + steps: + - name: "Checkout repository" + uses: actions/checkout@v4 + + - name: "Initialize CodeQL" + uses: github/codeql-action/init@v2 + with: + languages: "yaml" + queries: +security-and-quality, ./.github/codeql/custom-queries.qls + + - name: "Perform CodeQL Analysis" + uses: github/codeql-action/analyze@v2 diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml new file mode 100644 index 0000000000000000000000000000000000000000..c30737294082125e3864a55be0dbec4e3449c910 --- /dev/null +++ b/.github/workflows/docker-build.yml @@ -0,0 +1,95 @@ +name: Build Docker images (scheduled) + +on: + workflow_dispatch: + workflow_call: + schedule: + - cron: "0 1 * * *" + +concurrency: + group: docker-image-builds + cancel-in-progress: false + +env: + CI_SLACK_CHANNEL: ${{ secrets.CI_DOCKER_CHANNEL }} + +jobs: + trl-latest: + name: "Latest TRL GPU" + runs-on: ubuntu-latest + steps: + - name: Cleanup disk + run: | + sudo ls -l /usr/local/lib/ + sudo ls -l /usr/share/ + sudo du -sh /usr/local/lib/ + sudo du -sh /usr/share/ + sudo rm -rf /usr/local/lib/android + sudo rm -rf /usr/share/dotnet + sudo du -sh /usr/local/lib/ + sudo du -sh /usr/share/ + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v1 + - name: Check out code + uses: actions/checkout@v4 + - name: Login to DockerHub + uses: docker/login-action@v1 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_PASSWORD }} + + - name: Build and Push GPU + uses: docker/build-push-action@v4 + with: + context: ./docker/trl-latest-gpu + push: true + tags: huggingface/trl-latest-gpu + + - name: Post to Slack + if: always() + uses: huggingface/hf-workflows/.github/actions/post-slack@main + with: + slack_channel: ${{ env.CI_SLACK_CHANNEL }} + title: 🤗 Results of the trl-latest-gpu Docker Image build + status: ${{ job.status }} + slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} + + trl-source: + name: "Latest TRL + HF ecosystem from source" + runs-on: ubuntu-latest + steps: + - name: Cleanup disk + run: | + sudo ls -l /usr/local/lib/ + sudo ls -l /usr/share/ + sudo du -sh /usr/local/lib/ + sudo du -sh /usr/share/ + sudo rm -rf /usr/local/lib/android + sudo rm -rf /usr/share/dotnet + sudo du -sh /usr/local/lib/ + sudo du -sh /usr/share/ + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v1 + - name: Check out code + uses: actions/checkout@v4 + - name: Login to DockerHub + uses: docker/login-action@v1 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_PASSWORD }} + + - name: Build and Push GPU + uses: docker/build-push-action@v4 + with: + context: ./docker/trl-source-gpu + push: true + tags: huggingface/trl-source-gpu + + - name: Post to Slack + if: always() + uses: huggingface/hf-workflows/.github/actions/post-slack@main + with: + slack_channel: ${{ env.CI_SLACK_CHANNEL }} + title: 🤗 Results of the trl-source-gpu Docker Image build + status: ${{ job.status }} + slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} diff --git a/.github/workflows/issue_auto_labeller.yml b/.github/workflows/issue_auto_labeller.yml new file mode 100644 index 0000000000000000000000000000000000000000..53af32b9f0f66a34a7279ba35e8cabda45128e4d --- /dev/null +++ b/.github/workflows/issue_auto_labeller.yml @@ -0,0 +1,15 @@ +name: "Hugging Face Issue Labeler" +on: + issues: + types: opened + +jobs: + triage: + runs-on: ubuntu-latest + permissions: + issues: write + steps: + - uses: actions/checkout@v3 + - uses: August-murr/auto-labeler@main + with: + hf-api-key: ${{ secrets.CI_HF_API_TOKEN }} diff --git a/.github/workflows/pr_style_bot.yml b/.github/workflows/pr_style_bot.yml new file mode 100644 index 0000000000000000000000000000000000000000..481716fa0f6833fc9b1fcbae11d4ab965992ce3f --- /dev/null +++ b/.github/workflows/pr_style_bot.yml @@ -0,0 +1,127 @@ +name: PR Style Bot + +on: + workflow_dispatch: + + +permissions: + contents: write + pull-requests: write + +jobs: + run-style-bot: + if: > + contains(github.event.comment.body, '@bot /style') && + github.event.issue.pull_request != null + runs-on: ubuntu-latest + + steps: + - name: Extract PR details + id: pr_info + uses: actions/github-script@v6 + with: + script: | + const prNumber = context.payload.issue.number; + const { data: pr } = await github.rest.pulls.get({ + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: prNumber + }); + + // We capture both the branch ref and the "full_name" of the head repo + // so that we can check out the correct repository & branch (including forks). + core.setOutput("prNumber", prNumber); + core.setOutput("headRef", pr.head.ref); + core.setOutput("headRepoFullName", pr.head.repo.full_name); + + - name: Check out PR branch + uses: actions/checkout@v3 + env: + HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }} + HEADREF: ${{ steps.pr_info.outputs.headRef }} + with: + # Instead of checking out the base repo, use the contributor's repo name + repository: ${{ env.HEADREPOFULLNAME }} + ref: ${{ env.HEADREF }} + # You may need fetch-depth: 0 for being able to push + fetch-depth: 0 + token: ${{ secrets.GITHUB_TOKEN }} + + - name: Debug + env: + HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }} + HEADREF: ${{ steps.pr_info.outputs.headRef }} + PRNUMBER: ${{ steps.pr_info.outputs.prNumber }} + run: | + echo "PR number: ${{ env.PRNUMBER }}" + echo "Head Ref: ${{ env.HEADREF }}" + echo "Head Repo Full Name: ${{ env.HEADREPOFULLNAME }}" + + - name: Set up Python + uses: actions/setup-python@v4 + + - name: Install dependencies + run: | + pip install ruff pre-commit + + - name: Download Makefile from main branch + run: | + curl -o main_Makefile https://raw.githubusercontent.com/huggingface/trl/main/Makefile + + - name: Compare Makefiles + run: | + if ! diff -q main_Makefile Makefile; then + echo "Error: The Makefile has changed. Please ensure it matches the main branch." + exit 1 + fi + echo "No changes in Makefile. Proceeding..." + rm -rf main_Makefile + + - name: Run make style and make quality + run: | + make precommit || true + + - name: Commit and push changes + id: commit_and_push + env: + HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }} + HEADREF: ${{ steps.pr_info.outputs.headRef }} + PRNUMBER: ${{ steps.pr_info.outputs.prNumber }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + echo "HEADREPOFULLNAME: ${{ env.HEADREPOFULLNAME }}, HEADREF: ${{ env.HEADREF }}" + # Configure git with the Actions bot user + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + + # Make sure your 'origin' remote is set to the contributor's fork + git remote set-url origin "https://x-access-token:${GITHUB_TOKEN}@github.com/${{ env.HEADREPOFULLNAME }}.git" + + # If there are changes after running style/quality, commit them + if [ -n "$(git status --porcelain)" ]; then + git add . + git commit -m "Apply style fixes" + # Push to the original contributor's forked branch + git push origin HEAD:${{ env.HEADREF }} + echo "changes_pushed=true" >> $GITHUB_OUTPUT + else + echo "No changes to commit." + echo "changes_pushed=false" >> $GITHUB_OUTPUT + fi + + - name: Comment on PR with workflow run link + if: steps.commit_and_push.outputs.changes_pushed == 'true' + uses: actions/github-script@v6 + with: + script: | + const prNumber = parseInt(process.env.prNumber, 10); + const runUrl = `${process.env.GITHUB_SERVER_URL}/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}` + + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: prNumber, + body: `Style fixes have been applied. [View the workflow run here](${runUrl}).` + }); + env: + prNumber: ${{ steps.pr_info.outputs.prNumber }} diff --git a/.github/workflows/slow-tests.yml b/.github/workflows/slow-tests.yml new file mode 100644 index 0000000000000000000000000000000000000000..7b6e0698f1cceb514f0df74470927fc37a317ac6 --- /dev/null +++ b/.github/workflows/slow-tests.yml @@ -0,0 +1,98 @@ +name: Slow tests (on push) + +on: + push: + branches: [ main ] + paths: + # Run only when python files are modified + - "trl/**.py" + - "examples/**.py" +env: + RUN_SLOW: "yes" + IS_GITHUB_CI: "1" + SLACK_API_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} + + +jobs: + run_all_tests_single_gpu: + strategy: + fail-fast: false + matrix: + docker-image-name: ["huggingface/trl-latest-gpu:latest", "huggingface/trl-source-gpu:latest"] + runs-on: + group: aws-g4dn-2xlarge + env: + CUDA_VISIBLE_DEVICES: "0" + TEST_TYPE: "single_gpu_${{ matrix.docker-image-name }}" + container: + image: ${{ matrix.docker-image-name }} + options: --gpus all --shm-size "16gb" -e NVIDIA_DISABLE_REQUIRE=true + defaults: + run: + shell: bash + steps: + - uses: actions/checkout@v4 + - name: Pip install + run: | + source activate trl + pip install -e ".[test]" --no-deps + pip install pytest-reportlog parameterized + + - name: Run slow SFT tests on single GPU + if: always() + run: | + source activate trl + make slow_tests + + - name: Generate Report + if: always() + run: | + pip install slack_sdk tabulate + python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY + + + run_all_tests_multi_gpu: + strategy: + fail-fast: false + matrix: + docker-image-name: ["huggingface/trl-latest-gpu:latest", "huggingface/trl-source-gpu:latest"] + runs-on: + group: aws-g4dn-2xlarge + env: + CUDA_VISIBLE_DEVICES: "0,1" + TEST_TYPE: "multi_gpu_${{ matrix.docker-image-name }}" + container: + image: ${{ matrix.docker-image-name }} + options: --gpus all --shm-size "16gb" -e NVIDIA_DISABLE_REQUIRE=true + defaults: + run: + shell: bash + steps: + - uses: actions/checkout@v4 + - name: Pip install + run: | + source activate trl + pip install -e ".[test]" --no-deps + pip install pytest-reportlog parameterized + + - name: Run slow SFT tests on Multi GPU + if: always() + run: | + source activate trl + make slow_tests + + - name: Run end-to-end examples tests on multi GPU + if: always() + run: | + source activate trl + pip install deepspeed + make test_examples + + - name: Generate Reports + if: always() + run: | + pip install slack_sdk tabulate + python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY + python scripts/log_example_reports.py --text_file_name temp_results_sft_tests.txt >> $GITHUB_STEP_SUMMARY + python scripts/log_example_reports.py --text_file_name temp_results_dpo_tests.txt >> $GITHUB_STEP_SUMMARY + rm *.txt diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000000000000000000000000000000000000..1aae5eb88a3e873cdb948e66177d476dcb22ddf1 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,252 @@ +name: Tests + +on: + push: + branches: [ main ] + pull_request: + paths: + # Run only when relevant files are modified + - ".github/**.yml" + - "examples/**.py" + - "scripts/**.py" + - "tests/**.py" + - "trl/**.py" + - "setup.py" + +env: + TQDM_DISABLE: 1 + CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }} + +jobs: + check_code_quality: + name: Check code quality + runs-on: ubuntu-latest + if: github.event.pull_request.draft == false + steps: + - uses: actions/checkout@v4 + - name: Set up Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: 3.12 + - uses: pre-commit/action@v3.0.1 + with: + extra_args: --all-files + + tests: + name: Tests + strategy: + matrix: + python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] + fail-fast: false + runs-on: + group: aws-g4dn-2xlarge + container: + image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel + options: --gpus all + defaults: + run: + shell: bash + if: github.event.pull_request.draft == false + steps: + - name: Git checkout + uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Make and Git + run: | + apt-get update && apt-get install -y make git curl + + - name: Install uv + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + + - name: Create Python virtual environment + run: | + uv venv + uv pip install --upgrade setuptools wheel + + - name: Install dependencies + run: | + source .venv/bin/activate + uv pip install ".[dev]" + + - name: Test with pytest + run: | + source .venv/bin/activate + make test + + - name: Post to Slack + if: github.ref == 'refs/heads/main' && always() # Check if the branch is main + uses: huggingface/hf-workflows/.github/actions/post-slack@main + with: + slack_channel: ${{ env.CI_SLACK_CHANNEL }} + title: Results with Python ${{ matrix.python-version }} and latest dependencies + status: ${{ job.status }} + slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} + + tests_dev: + name: Tests with dev dependencies + runs-on: + group: aws-g4dn-2xlarge + container: + image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel + options: --gpus all + defaults: + run: + shell: bash + if: github.event.pull_request.draft == false + steps: + - name: Git checkout + uses: actions/checkout@v4 + + - name: Set up Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: '3.12' + + - name: Install Make and Git + run: | + apt-get update && apt-get install -y make git curl + + - name: Install uv + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + + - name: Create Python virtual environment + run: | + uv venv + uv pip install --upgrade setuptools wheel + + - name: Install dependencies + run: | + source .venv/bin/activate + uv pip install ".[dev]" + uv pip install -U git+https://github.com/huggingface/accelerate.git + uv pip install -U git+https://github.com/huggingface/datasets.git + uv pip install -U git+https://github.com/huggingface/transformers.git + + + - name: Test with pytest + run: | + source .venv/bin/activate + make test + + - name: Post to Slack + if: github.ref == 'refs/heads/main' && always() # Check if the branch is main + uses: huggingface/hf-workflows/.github/actions/post-slack@main + with: + slack_channel: ${{ env.CI_SLACK_CHANNEL }} + title: Results with Python 3.12 and dev dependencies + status: ${{ job.status }} + slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} + + tests_wo_optional_deps: + name: Tests without optional dependencies + runs-on: + group: aws-g4dn-2xlarge + container: + image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel + options: --gpus all + defaults: + run: + shell: bash + if: github.event.pull_request.draft == false + steps: + - name: Git checkout + uses: actions/checkout@v4 + + - name: Set up Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: '3.12' + + - name: Install Make and Git + run: | + apt-get update && apt-get install -y make git curl + + - name: Install uv + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + + - name: Create Python virtual environment + run: | + uv venv + uv pip install --upgrade setuptools wheel + + - name: Install dependencies + run: | + source .venv/bin/activate + uv pip install ".[test]" + + - name: Test with pytest + run: | + source .venv/bin/activate + make test + + - name: Post to Slack + if: github.ref == 'refs/heads/main' && always() # Check if the branch is main + uses: huggingface/hf-workflows/.github/actions/post-slack@main + with: + slack_channel: ${{ env.CI_SLACK_CHANNEL }} + title: Results with Python 3.12 without optional dependencies + status: ${{ job.status }} + slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} + + tests_min_versions: + name: Tests with minimum versions + runs-on: + group: aws-g4dn-2xlarge + container: + image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel + options: --gpus all + defaults: + run: + shell: bash + if: github.event.pull_request.draft == false + steps: + - name: Git checkout + uses: actions/checkout@v4 + + - name: Set up Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: '3.12' + + - name: Install Make and Git + run: | + apt-get update && apt-get install -y make git curl + + - name: Install uv + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + + - name: Create Python virtual environment + run: | + uv venv + uv pip install --upgrade setuptools wheel + + - name: Install dependencies + run: | + source .venv/bin/activate + uv pip install ".[dev]" + uv pip install accelerate==1.4.0 + uv pip install datasets==3.0.0 + uv pip install transformers==4.51.0 + + - name: Test with pytest + run: | + source .venv/bin/activate + make test + + - name: Post to Slack + if: github.ref == 'refs/heads/main' && always() # Check if the branch is main + uses: huggingface/hf-workflows/.github/actions/post-slack@main + with: + slack_channel: ${{ env.CI_SLACK_CHANNEL }} + title: Results with Python 3.12 and minimum dependencies versions + status: ${{ job.status }} + slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/tests_latest.yml b/.github/workflows/tests_latest.yml new file mode 100644 index 0000000000000000000000000000000000000000..569078a7a99145f6628fc2f41a831996d577a72a --- /dev/null +++ b/.github/workflows/tests_latest.yml @@ -0,0 +1,66 @@ +name: Tests latest TRL release with dev dependencies + +on: + schedule: + - cron: '0 0 * * *' # Runs daily at midnight UTC + + workflow_dispatch: + +env: + TQDM_DISABLE: 1 + CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }} + +jobs: + tests: + name: Tests latest TRL release with dev dependencies + runs-on: + group: aws-g4dn-2xlarge + container: + image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel + options: --gpus all + defaults: + run: + shell: bash + steps: + - name: Git checkout + uses: actions/checkout@v4 + with: { ref: v0.18-release } + + - name: Set up Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: '3.12' + + - name: Install Make and Git + run: | + apt-get update && apt-get install -y make git curl + + - name: Install uv + run: | + curl -LsSf https://astral.sh/uv/install.sh | sh + + - name: Create Python virtual environment + run: | + uv venv + uv pip install --upgrade setuptools wheel + + - name: Install dependencies + run: | + source .venv/bin/activate + uv pip install ".[dev]" + uv pip install -U git+https://github.com/huggingface/accelerate.git + uv pip install -U git+https://github.com/huggingface/datasets.git + uv pip install -U git+https://github.com/huggingface/transformers.git + + - name: Test with pytest + run: | + source .venv/bin/activate + make test + + - name: Post to Slack + uses: huggingface/hf-workflows/.github/actions/post-slack@main + with: + slack_channel: ${{ env.CI_SLACK_CHANNEL }} + title: Results of latest TRL with Python 3.12 and dev dependencies + status: ${{ job.status }} + slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yml new file mode 100644 index 0000000000000000000000000000000000000000..5595aee2aa4f345c41f064bcf17b97a7a2e5c9a1 --- /dev/null +++ b/.github/workflows/trufflehog.yml @@ -0,0 +1,18 @@ +on: + push: + +name: Secret Leaks + +jobs: + trufflehog: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Secret Scanning + uses: trufflesecurity/trufflehog@853e1e8d249fd1e29d0fcc7280d29b03df3d643d + with: + # exclude buggy postgres detector that is causing false positives and not relevant to our codebase + extra_args: --results=verified,unknown --exclude-detectors=postgres diff --git a/.github/workflows/upload_pr_documentation.yml b/.github/workflows/upload_pr_documentation.yml new file mode 100644 index 0000000000000000000000000000000000000000..2ad2ba0e8de52699f60c2da7792dab742dd6f200 --- /dev/null +++ b/.github/workflows/upload_pr_documentation.yml @@ -0,0 +1,16 @@ +name: Upload PR Documentation + +on: + workflow_run: + workflows: ["Build PR Documentation"] + types: + - completed + +jobs: + build: + uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main + with: + package_name: trl + secrets: + hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} + comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }} \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..96d5bda2e202649fa737848ce2c5edfb4d3de278 --- /dev/null +++ b/.gitignore @@ -0,0 +1,144 @@ +*.bak +.last_checked +.gitconfig +*.bak +*.log +*~ +~* +_tmp* +tmp* +tags + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +env/ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# dotenv +.env + +# virtualenv +.venv +venv/ +ENV/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +.vscode +*.swp + +# osx generated files +.DS_Store +.DS_Store? +.Trashes +ehthumbs.db +Thumbs.db +.idea + +# pytest +.pytest_cache + +# tools/trust-doc-nbs +docs_src/.last_checked + +# symlinks to fastai +docs_src/fastai +tools/fastai + +# link checker +checklink/cookies.txt + +# .gitconfig is now autogenerated +.gitconfig + +# wandb files +nbs/wandb/ +examples/notebooks/wandb/ +wandb/ \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7bd7aa1e45d20a5c9b8ae66001d1bafbbc36f55c --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,17 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.11.10 + hooks: + - id: ruff-check + types_or: [ python, pyi ] + args: [ --fix ] + - id: ruff-format + types_or: [ python, pyi ] + + # - repo: https://github.com/codespell-project/codespell + # rev: v2.1.0 + # hooks: + # - id: codespell + # args: + # - --ignore-words-list=nd,reacher,thist,ths,magent,ba + # - --skip=docs/css/termynal.css,docs/js/termynal.js diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 0000000000000000000000000000000000000000..3f8e21aed43bd57b67d409fcc744e50c8db6190c --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,34 @@ +cff-version: 1.2.0 +title: 'TRL: Transformer Reinforcement Learning' +message: >- + If you use this software, please cite it using the + metadata from this file. +type: software +authors: + - given-names: Leandro + family-names: von Werra + - given-names: Younes + family-names: Belkada + - given-names: Lewis + family-names: Tunstall + - given-names: Edward + family-names: Beeching + - given-names: Tristan + family-names: Thrush + - given-names: Nathan + family-names: Lambert + - given-names: Shengyi + family-names: Huang + - given-names: Kashif + family-names: Rasul + - given-names: Quentin + family-names: Gallouédec +repository-code: 'https://github.com/huggingface/trl' +abstract: "With trl you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the transformers library by \U0001F917 Hugging Face. Therefore, pre-trained language models can be directly loaded via transformers. At this point, most decoder and encoder-decoder architectures are supported." +keywords: + - rlhf + - deep-learning + - pytorch + - transformers +license: Apache-2.0 +version: 0.18 diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..ef09fa1375a81440bf0733b659045453a5476c43 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,133 @@ + +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, caste, color, religion, or sexual +identity and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the overall + community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or advances of + any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email address, + without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +feedback@huggingface.co. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series of +actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or permanent +ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within the +community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.1, available at +[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. + +Community Impact Guidelines were inspired by +[Mozilla's code of conduct enforcement ladder][Mozilla CoC]. + +For answers to common questions about this code of conduct, see the FAQ at +[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at +[https://www.contributor-covenant.org/translations][translations]. + +[homepage]: https://www.contributor-covenant.org +[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html +[Mozilla CoC]: https://github.com/mozilla/diversity +[FAQ]: https://www.contributor-covenant.org/faq +[translations]: https://www.contributor-covenant.org/translations \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..4e59c622b03078245a177e34efda56038af1c834 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,767 @@ +# How to contribute to TRL? + +Everyone is welcome to contribute, and we value everybody's contribution. Code +contributions are not the only way to help the community. Answering questions, helping +others, and improving the documentation are also immensely valuable. + +It also helps us if you spread the word! Reference the library in blog posts +about the awesome projects it made possible, shout out on Twitter every time it has +helped you, or simply ⭐️ the repository to say thank you. + +However you choose to contribute, please be mindful and respect our +[code of conduct](https://github.com/huggingface/trl/blob/main/CODE_OF_CONDUCT.md). + +**This guide was heavily inspired by the awesome [scikit-learn guide to contributing](https://github.com/scikit-learn/scikit-learn/blob/main/CONTRIBUTING.md).** + +## Ways to contribute + +There are several ways you can contribute to TRL: + +* Fix outstanding issues with the existing code. +* Submit issues related to bugs or desired new features. +* Implement trainers for new post-training algorithms. +* Contribute to the examples or the documentation. + +If you don't know where to start, there is a special [Good First +Issue](https://github.com/huggingface/trl/labels/%F0%9F%91%B6%20good%20first%20issue) listing. It will give you a list of +open issues that are beginner-friendly and help you start contributing to open-source. The best way to do that is to open a Pull Request and link it to the issue that you'd like to work on. We try to give priority to opened PRs as we can easily track the progress of the fix, and if the contributor does not have time anymore, someone else can take the PR over. + +For something slightly more challenging, you can also take a look at the [Good Second Issue](https://github.com/huggingface/trl/labels/Good%20Second%20Issue) list. In general though, if you feel like you know what you're doing, go for it and we'll help you get there! 🚀 + +> All contributions are equally valuable to the community. 🥰 + +Before you start contributing make sure you have installed all the dev tools: + +```bash +pip install -e .[dev] +``` + +## Fixing outstanding issues + +If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](#submitting-a-pull-request-pr) and open a Pull Request! + +## Submitting a bug-related issue or feature request + +Do your best to follow these guidelines when submitting a bug-related issue or a feature request. It will make it easier for us to come back to you quickly and with good feedback. + +### Did you find a bug? + +The TRL library is robust and reliable thanks to users who report the problems they encounter. + +Before you report an issue, we would really appreciate it if you could **make sure the bug was not +already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the library itself, and not your code. + +Once you've confirmed the bug hasn't already been reported, please include the following information in your issue so we can quickly resolve it: + +* Your **OS type and version**, **Python**, **PyTorch**, **TRL** and **Transformers** versions. +* A short, self-contained, code snippet that allows us to reproduce the bug in + less than 30s. +* The *full* traceback if an exception is raised. +* Attach any other additional information, like screenshots, you think may help. + +To get the OS and software versions automatically, run the following command: + +```bash +trl env +``` + +### Do you want a new feature? + +If there is a new feature you'd like to see in TRL, please open an issue and describe: + +1. What is the *motivation* behind this feature? Is it related to a problem or frustration with the library? Is it a feature related to something you need for a project? Is it something you worked on and think it could benefit the community? + + Whatever it is, we'd love to hear about it! + +2. Describe your requested feature in as much detail as possible. The more you can tell us about it, the better we'll be able to help you. +3. Provide a *code snippet* that demonstrates the feature's usage. +4. If the feature is related to a paper, please include a link. + +If your issue is well written we're already 80% of the way there by the time you create it. + +## Do you want to implement a new trainer? + +New post-training methods are published frequently and those that satisfy the following criteria are good candidates to be integrated into TRL: + +* **Simplicity:** Does the new method achieve similar performance as prior methods, but with less complexity? A good example is Direct Preference Optimization (DPO) [[Rafailov et al, 2023]](https://huggingface.co/papers/2305.18290), which provided a simpler and compelling alternative to RLHF methods. +* **Efficiency:** Does the new method provide a significant improvement in training efficiency? A good example is Odds Ratio Preference Optimization (ORPO) [[Hong et al, 2023]](https://huggingface.co/papers/2403.07691), which utilizes a similar objective as DPO but requires half the GPU VRAM. + +Methods that only provide incremental improvements at the expense of added complexity or compute costs are unlikely to be included in TRL. + +If you want to implement a trainer for a new post-training method, first open an issue and provide the following information: + +* A short description of the method and a link to the paper. +* Link to the implementation if it is open-sourced. +* Link to model weights trained with the method if they are available. + +Based on the community and maintainer feedback, the next step will be to implement the trainer and config classes. See the following examples for inspiration: + +* Paired preference optimisation: [`dpo_trainer.py`](./trl/trainer/dpo_trainer.py) and [`dpo_config.py`](./trl/trainer/dpo_config.py) +* RL-based optimisation: [`rloo_trainer.py](./trl/trainer/rloo_trainer.py) and [`rloo_config.py](./trl/trainer/rloo_config.py) +* Online optimisation: [`online_dpo_trainer.py`](./trl/trainer/online_dpo_trainer.py) and [`online_dpo_config.py`](./trl/trainer/online_dpo_config.py) + +## Do you want to add documentation? + +We're always looking for improvements to the documentation that make it more clear and accurate. Please let us know how the documentation can be improved, such as typos, dead links, and any missing, unclear, or inaccurate content... We'll be happy to make the changes or help you contribute if you're interested! + +## Submitting a pull request (PR) + +Before writing code, we strongly advise you to search through the existing PRs or +issues to make sure that nobody is already working on the same thing. If you are +unsure, it is always a good idea to open an issue to get some feedback. + +You will need basic `git` proficiency to be able to contribute to +TRL. `git` is not the easiest tool to use but it has the greatest +manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro +Git](https://git-scm.com/book/en/v2) is a very good reference. + +Follow these steps to start contributing: + +1. Fork the [repository](https://github.com/huggingface/trl) by + clicking on the 'Fork' button on the repository's page. This creates a copy of the code + under your GitHub user account. + +2. Clone your fork to your local disk, and add the base repository as a remote. The following command + assumes you have your public SSH key uploaded to GitHub. See the following guide for more + [information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository). + + ```bash + $ git clone git@github.com:/trl.git + $ cd trl + $ git remote add upstream https://github.com/huggingface/trl.git + ``` + +3. Create a new branch to hold your development changes, and do this for every new PR you work on. + + Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)): + + ```bash + $ git checkout main + $ git fetch upstream + $ git merge upstream/main + ``` + + Once your `main` branch is synchronized, create a new branch from it: + + ```bash + $ git checkout -b a-descriptive-name-for-my-changes + ``` + + **Do not** work on the `main` branch. + +4. Set up a development environment by running the following command in a conda or a virtual environment you've created for working on this library: + + ```bash + $ pip install -e .[dev] + ``` + + (If TRL was already installed in the virtual environment, remove + it with `pip uninstall trl` before reinstalling it.) + + Alternatively, if you are using [Visual Studio Code](https://code.visualstudio.com/Download), the fastest way to get set up is by using + the provided Dev Container. Documentation on how to get started with dev containers is available [here](https://code.visualstudio.com/docs/remote/containers). + +5. Develop the features on your branch. + + As you work on the features, you should make sure that the test suite + passes. You should run the tests impacted by your changes like this (see + below an explanation regarding the environment variable): + + ```bash + $ pytest tests/.py + ``` + + > For the following commands leveraging the `make` utility. + + You can also run the full suite with the following command. + + ```bash + $ make test + ``` + + TRL relies on `ruff` for maintaining consistent code formatting across its source files. Before submitting any PR, you should apply automatic style corrections and run code verification checks. + + We provide a `precommit` target in the `Makefile` that simplifies this process by running all required checks and optimizations on only the files modified by your PR. + + To apply these checks and corrections in one step, use: + + ```bash + $ make precommit + ``` + + This command runs the following: + - Executes `pre-commit` hooks to automatically fix style issues with `ruff` and other tools. + - Runs additional scripts such as adding copyright information. + + If you prefer to apply the style corrections separately or review them individually, the `pre-commit` hook will handle the formatting for the files in question. + + Once you're happy with your changes, add changed files using `git add` and + make a commit with `git commit` to record your changes locally: + + ```bash + $ git add modified_file.py + $ git commit + ``` + + Please write [good commit messages](https://chris.beams.io/posts/git-commit/). + + It is a good idea to sync your copy of the code with the original + repository regularly. This way you can quickly account for changes: + + ```bash + $ git fetch upstream + $ git rebase upstream/main + ``` + + Push the changes to your account using: + + ```bash + $ git push -u origin a-descriptive-name-for-my-changes + ``` + +6. Once you are satisfied (**and the checklist below is happy too**), go to the + webpage of your fork on GitHub. Click on 'Pull request' to send your changes + to the project maintainers for review. + +7. It's ok if maintainers ask you for changes. It happens to core contributors too! To ensure everyone can review your changes in the pull request, work on your local branch and push the updates to your fork. They will automatically appear in the pull request. + + +### Checklist + +1. The title of your pull request should be a summary of its contribution; +2. If your pull request addresses an issue, please mention the issue number in + the pull request description to make sure they are linked (and people + consulting the issue know you are working on it); +3. To indicate a work in progress please prefix the title with `[WIP]`, or mark + the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate + it from PRs ready to be merged; +4. Make sure existing tests pass; +5. Add high-coverage tests. No quality testing = no merge. + + +### Tests + +An extensive test suite is included to test the library behavior and several examples. Library tests can be found in +the [tests folder](https://github.com/huggingface/trl/tree/main/tests). + +We use `pytest` to run the tests. From the root of the +repository here's how to run tests with `pytest` for the library: + +```bash +$ python -m pytest -sv ./tests +``` + +That's how `make test` is implemented (without the `pip install` line)! + +You can specify a smaller set of tests to test only the feature +you're working on. + +### Default values guidelines + +1. **Use defaults when appropriate**: + +Provide default values unless the parameter's value varies significantly by use case. For example, datasets or models should not have defaults, but parameters like `learning_rate` should. + +2. **Prioritize proven defaults**: + +Default values should align with those recommended in the original paper or method. Alternatives require strong evidence of superior performance in most cases. + +3. **Ensure safety and predictability**: + +Defaults must be safe, expected and reliable. Avoid settings that could lead to surprising outcomes, such as excessive memory usage or poor performance in edge cases. + +4. **Balance consistency and flexibility**: + +Aim for consistent defaults across similar functions or methods. However, consistency should not be preferred to point 2 or 3. + +5. **Opt-in for new features**: + +Do not enable new features or improvements (e.g., novel loss functions) by default. Users should explicitly opt-in to use these. + +### Writing documentation + +High-quality documentation is crucial for maintaining a project that is easy to use, understand, and extend. When adding new features, ensure they are thoroughly documented to maintain consistency and clarity throughout the project. + +To illustrate what good documentation looks like, here’s an example of a well-documented function: + +````python +def replicate_str(string: str, n: int, sep: str = " ") -> str: + r""" + Replicate a string `n` times with a separator. + + Args: + string (`str`): + String to replicate. + n (`int`): + Number of times to replicate the string. + sep (`str`, *optional*, defaults to `" "`): + Separator to use between each replication. + + Returns: + `str`: The replicated string. + + Examples: + ```python + >>> replicate_str("hello", 3) + "hello hello hello" + >>> replicate_str("hello", 3, sep=", ") + "hello, hello, hello" + ``` + """ + return sep.join([string] * n) +```` + +* **Line Wrapping:** Applied a consistent line wrap at column 120 to improve readability. +* **Definite Articles:** Removed definite articles where possible to streamline language. (Eg: Changed "The string to replicate" to "String to replicate") +* **Type Annotations:** + * Always include type definitions, indicating if a parameter is optional and specifying the default value. + * Note that `Optional` means that the value can be `None`, and `*optional*` means that it is not required for the user to pass a value. + E.g., for arguments that can't be `None` and aren't required: + + ```python + foo (`int`, *optional*, defaults to `4`): + ``` + + For arguments that can be `None` and are required: + + ```python + foo (`Optional[int]`): + ``` + + for arguments that can be `None` and aren't required: + + ```python + foo (`Optional[int]`, *optional*, defaults to `None`): + ``` + +* **String Defaults:** + * Ensured that default string values are wrapped in double quotes: + + ```python + defaults to `"foo"` + ``` + +* **Dictionary Typing:** + * Replaced generic `dict` type hints with more explicit `dict[str, Any]` to clarify expected key-value pairs. +* **Default Value Formatting:** + * Consistently surrounded default values with backticks for improved formatting: + + ```python + defaults to `4` + ``` + +* **Sub-sectioning:** When the number of arguments is large, consider breaking them into sub-sections for better readability. + + ```python + def calculate_statistics(data: list[float], precision: int = 2, include_variance: bool = False) -> dict[str, float]: + r""" + Calculates basic statistics for a given dataset. + + Args: + > Data inputs + + data (`list[float]`): + A list of numerical values to analyze. + + > Configuration parameters + + precision (`int`, *optional*, defaults to `2`): + Number of decimal places to round the results. + include_variance (`bool`, *optional*, defaults to `False`): + Whether to include the variance of the dataset in the results. + + Returns: + `dict[str, float]`: + A dictionary containing calculated statistics such as mean, median, and optionally variance. + """ + ... + ``` + +### Deprecation and backward compatibility + +Our approach to deprecation and backward compatibility is flexible and based on the feature’s usage and impact. Each deprecation is carefully evaluated, aiming to balance innovation with user needs. + +When a feature or component is marked for deprecation, its use will emit a warning message. This warning will include: + +- **Transition Guidance**: Instructions on how to migrate to the alternative solution or replacement. +- **Removal Version**: The target version when the feature will be removed, providing users with a clear timeframe to transition. + +Example: + + ```python + warnings.warn( + "The `Trainer.foo` method is deprecated and will be removed in version 0.14.0. " + "Please use the `Trainer.bar` class instead.", + FutureWarning, + ) + ``` + +The deprecation and removal schedule is based on each feature's usage and impact, with examples at two extremes: + +- **Experimental or Low-Use Features**: For a feature that is experimental or has limited usage, backward compatibility may not be maintained between releases. Users should therefore anticipate potential breaking changes from one version to the next. + +- **Widely-Used Components**: For a feature with high usage, we aim for a more gradual transition period of approximately **5 months**, generally scheduling deprecation around **5 minor releases** after the initial warning. + +These examples represent the two ends of a continuum. The specific timeline for each feature will be determined individually, balancing innovation with user stability needs. + +### Working with warnings + +Warnings play a critical role in guiding users toward resolving potential issues, but they should be used thoughtfully to avoid unnecessary noise. Unlike logging, which provides informational context or operational details, warnings signal conditions that require attention and action. Overusing warnings can dilute their importance, leading users to ignore them entirely. + +#### Definitions + +- **Correct**: An operation is correct if it is valid, follows the intended approach, and aligns with the current best practices or guidelines within the codebase. This is the recommended or intended way to perform the operation. +- **Supported**: An operation is supported if it is technically valid and works within the current codebase, but it may not be the most efficient, optimal, or recommended way to perform the task. This includes deprecated features or legacy approaches that still work but may be phased out in the future. + +#### Choosing the right message + +- **Correct → No warning**: + If the operation is fully valid and expected, no message should be issued. The system is working as intended, so no warning is necessary. + +- **Correct but deserves attention → No warning, possibly a log message**: + When an operation is correct but uncommon or requires special attention, providing an informational message can be helpful. This keeps users informed without implying any issue. If available, use the logger to output this message. Example: + + ```python + logger.info("This is an informational message about a rare but correct operation.") + ``` + +- **Correct but very likely a mistake → Warning with option to disable**: + In rare cases, you may want to issue a warning for a correct operation that’s very likely a mistake. In such cases, you must provide an option to suppress the warning. This can be done with a flag in the function. Example: + + ```python + def my_function(foo, bar, _warn=True): + if foo == bar: + if _warn: + warnings.warn("foo and bar are the same, this is likely a mistake. Ignore this warning by setting `_warn=False`.") + # Do something + ``` + +- **Supported but not correct → Warning**: + If the operation is technically supported but is deprecated, suboptimal, or could cause future issues (e.g., conflicting arguments), a warning should be raised. This message should be actionable, meaning it must explain how to resolve the issue. Example: + + ```python + def my_function(foo, bar): + if foo and bar: + warnings.warn("Both `foo` and `bar` were provided, but only one is allowed. Ignoring `foo`. Please pass only one of these arguments.") + # Do something + ``` + +- **Not supported → Exception**: + If the operation is invalid or unsupported, raise an exception. This indicates that the operation cannot be performed and requires immediate attention. Example: + + ```python + def my_function(foo, bar): + if foo and bar: + raise ValueError("Both `foo` and `bar` were provided, but only one is allowed. Please pass only one of these arguments.") + ``` + +By following this classification, you ensure that warnings, information, and exceptions are used appropriately, providing clear guidance to the user without cluttering the system with unnecessary messages. + + +## Making a release + +> [!NOTE] +> VERSION needs to be formatted following the `v{major}.{minor}.{patch}` convention. We need to follow this convention to be able to retrieve versioned scripts. + +#### 0. Prerequisites + +- Dependencies: + - twine: `pip install build twine` +- Create an account in (and join the `trl` project): + - PyPI: https://pypi.org/ + - Test PyPI: https://test.pypi.org/ + +### Major/Minor Release + +#### 1. Ensure your local repository is up to date with the upstream repository + +```bash +git checkout main +git pull origin main +``` + +> [!WARNING] +> Do not merge other pull requests into `main` until the release is done. This is to ensure that the release is stable and does not include any untested changes. Announce internally (#trl-internal) to other maintainers that you are doing a release and that they must not merge PRs until the release is done. + +#### 2. Create a release branch from main + +```bash +git checkout -b release-v{major}.{minor} +``` + +#### 3. Change the version in the following files + +- `.github/workflows/tests_latest.yml`: + ```diff + - with: { ref: v{major}.{minor-1}-release } + + with: { ref: v{major}.{minor}-release } + ``` +- `CITATION.cff` + ```diff + - version: {major}.{minor-1} + + version: {major}.{minor} + ``` +- `trl/__init__.py` + ```diff + - __version__ = "{major}.{minor}.0.dev0" + + __version__ = "{major}.{minor}.0" + ``` +- `setup.cfg` + ```diff + - version = {major}.{minor}.0.dev0 + + version = {major}.{minor}.0 + ``` + +#### 4. Commit and push these changes + +```shell +git add .github/workflows/tests_latest.yml CITATION.cff trl/__init__.py setup.cfg +git commit -m 'Release: {major}.{minor}' +git push origin release-v{major}.{minor} +``` + +#### 5. Create a pull request + +from `release-v{major}.{minor}` to `main`, named `Release: v{major}.{minor}`, wait for tests to pass, and request a review. + +#### 6. Once the pull request is approved, merge it into `main` + +#### 7. Add a tag in git to mark the release + +```shell +git checkout main +git pull origin main +git tag -a v{major}.{minor}.0 -m 'Adds tag v{major}.{minor}.0 for PyPI' +git push origin v{major}.{minor}.0 +``` + +#### 8. Create a branch `v{major}.{minor}-release` for future patch releases. + +```shell +git checkout -b v{major}.{minor}-release +git push origin v{major}.{minor}-release +``` + +This ensures that future patch releases (`v{major}.{minor}.1`, `v{major}.{minor}.2`, etc.) can be made separately from `main`. + +#### 9. Create the wheels for your release + +These are the artifacts that will be uploaded to PyPI and installed by users via `pip install trl`. + +Clean previous builds: + +```shell +rm -rf build dist +``` + +At the root of your repo, run + +```bash +python -m build . +``` + +This will create a folders named `dist` with the new versions of your package. + +#### 10. Upload the package to PyPI Test + +> [!IMPORTANT] +> Do not skip this step. It is important to test the package before uploading it to the main PyPI server. + +```shell +twine upload dist/* -r testpypi +``` + +Then in a fresh environment containing all dependencies you need, try to install your new package from the PyPI test server. + +```bash +pip install -i https://test.pypi.org/simple/ trl +``` + +You might get errors for missing dependencies since the PyPI test server does not contain all packages like PyPI does. To make sure you have everything you can do: + +```bash +pip install trl +pip uninstall trl +``` + +(the second line will remove trl but keep all its dependencies). + +Also make sure you can actually use the package! Run the following line: + +```bash +python -c "from trl import *" +``` + +along with anything that tests: + +- the core feature of your package +- the new features you’re adding in the release + +#### 11. Publish on PyPI + +> [!WARNING] +> This can't be reverted. Make sure you have tested everything before doing this step. + +```shell +twine upload dist/* +``` + +#### 12. Create a GitHub Release + +1. Go to the repo’s [releases section](https://github.com/huggingface/trl/releases) on GitHub. +2. Click **Draft a new release**. +3. Select the `v{major}.{minor}.0` tag you just created in step 7. +4. Add a title (`v{major}.{minor}.0`) and a short description of what’s new. +5. Click **Publish Release**. + +#### 13. Bump to dev version + +1. Create a branch `bump-dev-version-{major}.{minor+1}` from `main` and checkout to it. + + ```shell + git checkout -b bump-dev-version-{major}.{minor+1} + ``` + +2. Change the version in the following files: + 1. `trl/__init__.py` + ```diff + - __version__ = "{major}.{minor}.0" + + __version__ = "{major}.{minor+1}.0.dev0" + ``` + 2. `setup.cfg` + ```diff + - version = {major}.{minor}.0 + + version = {major}.{minor+1}.0.dev0 + ``` + +3. Commit and push these changes + + ```shell + git add trl/__init__.py setup.cfg + git commit -m '⬆️ Bump dev version' + git push origin bump-dev-version-{major}.{minor+1} + ``` + +4. Create a pull request from `bump-dev-version-{major}.{minor+1}` to `main`, named `⬆️ Bump dev version`, and request urgent review. + +5. Once the pull request is approved, merge it into `main`. + +6. The codebase is now ready for the next development cycle, inform the team in the #trl-internal channel. + + +## Making a patch release + +#### 1. Ensure your local repository is up to date with the upstream repository + +```bash +git checkout v{major}.{minor}-release +git pull origin main +``` + +#### 2. Cherry-pick the changes you want to include in the patch release + +```bash +git cherry-pick +git cherry-pick +... +``` + +#### 3. Change the version in the following files + +- `trl/__init__.py` + ```diff + - __version__ = "{major}.{minor}.{patch-1}" + + __version__ = "{major}.{minor}.{patch}" + ``` +- `setup.cfg` + ```diff + - version = {major}.{minor}.{patch-1} + + version = {major}.{minor}.{patch} + ``` + +#### 4. Commit and push these changes + +```shell +git add trl/__init__.py setup.cfg +git commit -m 'Release: {major}.{minor}.{patch}' +git push origin v{major}.{minor}-release +``` + +#### 5. Wait for the CI to pass + +#### 6. Add a tag in git to mark the release + +```shell +git tag -a v{major}.{minor}.{patch} -m 'Adds tag v{major}.{minor}.{patch} for PyPI' +git push origin v{major}.{minor}.{patch} +``` + +#### 7. Create the wheels for your release + +These are the artifacts that will be uploaded to PyPI and installed by users via `pip install trl`. + +Clean previous builds: + +```shell +rm -rf build dist +``` + +At the root of your repo, run + +```bash +python -m build . +``` + +This will create a folders named `dist` with the new versions of your package. + +#### 8. Upload the package to PyPI Test + +> [!IMPORTANT] +> Do not skip this step. It is important to test the package before uploading it to the main PyPI server. + +```shell +twine upload dist/* -r testpypi +``` + +Then in a fresh environment containing all dependencies you need, try to install your new package from the PyPI test server. + +```bash +pip install -i https://test.pypi.org/simple/ trl +``` + +You might get errors for missing dependencies since the PyPI test server does not contain all packages like PyPI does. To make sure you have everything you can do: + +```bash +pip install trl +pip uninstall trl +``` + +(the second line will remove trl but keep all its dependencies). + +Also make sure you can actually use the package! Run the following line: + +```bash +python -c "from trl import *" +``` + +along with anything that tests: + +- the core feature of your package +- the new features you’re adding in the release + +#### 9. Publish on PyPI + +> [!WARNING] +> This can't be reverted. Make sure you have tested everything before doing this step. + +```shell +twine upload dist/* +``` + +#### 10. Create a GitHub Release + +1. Go to the repo’s [releases section](https://github.com/huggingface/trl/releases) on GitHub. +2. Click **Draft a new release**. +3. Select the `v{major}.{minor}.{patch}` tag you just created in step 7. +4. Add a title (`v{major}.{minor}.{patch}`) and a short description of what’s new. +5. Click **Publish Release**. diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..89424f4b95f45811e4329e5e5c23e6e95a6cb6e4 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,37 @@ +# https://huggingface.co/docs/hub/spaces-dev-mode#docker-spaces + +FROM python:3.13-bookworm + +RUN apt-get update +RUN apt-get install -y \ + bash \ + curl \ + git \ + git-lfs \ + htop \ + procps \ + nano \ + vim \ + wget +RUN rm -fr /var/lib/apt/lists/* + +RUN useradd -m -u 1000 user + +WORKDIR /app +RUN chown user /app +RUN chmod 755 /app + +USER user +ENV PATH="/home/user/.local/bin:$PATH" +RUN curl -fsSL https://pyenv.run | bash +RUN curl -LsSf https://astral.sh/uv/install.sh | sh + +COPY --chown=user . /app + +RUN ls -la /app + +RUN uv sync + +# `7860` is the default port for Hugging Face Spaces running on Docker +# https://huggingface.co/docs/hub/en/spaces-config-reference +CMD ["python", "-m", "http.server", "--directory", "public", "7860"] diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..a98c551ed4f7ba78782da3c4f4c47ca6443592b9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2020-2025 The HuggingFace Team + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..8855af1a5ae3b380543388db0565b67c25c67246 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,6 @@ +include LICENSE +include CONTRIBUTING.md +include README.md +recursive-exclude * __pycache__ +include trl/templates/*.md +include trl/accelerate_configs/*.yaml \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..8b6b53e31f8a6c0352635e57428f8a7d26419257 --- /dev/null +++ b/Makefile @@ -0,0 +1,29 @@ +.PHONY: test precommit common_tests slow_tests test_examples tests_gpu + +check_dirs := examples tests trl + +ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs +COMMAND_FILES_PATH = `pwd`/commands + +test: + pytest -n auto -m "not slow and not low-priority" -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)' tests/ + +precommit: + python scripts/add_copyrights.py + pre-commit run --all-files + +slow_tests: + pytest -m "slow" tests/ $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",) + +test_examples: + touch temp_results_sft_tests.txt + for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \ + TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_sft.sh; \ + echo $$?','$${file} >> temp_results_sft_tests.txt; \ + done + + touch temp_results_dpo_tests.txt + for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \ + TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_dpo.sh; \ + echo $$?','$${file} >> temp_results_dpo_tests.txt; \ + done diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..3a517ed37c0f2a30d718287239dd4f2c5f9b1001 --- /dev/null +++ b/README.md @@ -0,0 +1,210 @@ +--- +title: Trl +emoji: 🚀 +colorFrom: yellow +colorTo: green +sdk: docker +pinned: false +--- + +# TRL - Transformer Reinforcement Learning + +
+TRL Banner +
+ +

+ +

+

A comprehensive library to post-train foundation models

+

+ +

+ License + Documentation + GitHub release + Hugging Face Hub +

+ +## Overview + +TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO). Built on top of the [🤗 Transformers](https://github.com/huggingface/transformers) ecosystem, TRL supports a variety of model architectures and modalities, and can be scaled-up across various hardware setups. + +## Highlights + +- **Trainers**: Various fine-tuning methods are easily accessible via trainers like [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`GRPOTrainer`](https://huggingface.co/docs/trl/grpo_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer) and more. + +- **Efficient and scalable**: + - Leverages [🤗 Accelerate](https://github.com/huggingface/accelerate) to scale from single GPU to multi-node clusters using methods like [DDP](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) and [DeepSpeed](https://github.com/deepspeedai/DeepSpeed). + - Full integration with [🤗 PEFT](https://github.com/huggingface/peft) enables training on large models with modest hardware via quantization and LoRA/QLoRA. + - Integrates [🦥 Unsloth](https://github.com/unslothai/unsloth) for accelerating training using optimized kernels. + +- **Command Line Interface (CLI)**: A simple interface lets you fine-tune with models without needing to write code. + +## Installation + +### Python Package + +Install the library using `pip`: + +```bash +pip install trl +``` + +### From source + +If you want to use the latest features before an official release, you can install TRL from source: + +```bash +pip install git+https://github.com/huggingface/trl.git +``` + +### Repository + +If you want to use the examples you can clone the repository with the following command: + +```bash +git clone https://github.com/huggingface/trl.git +``` + +## Quick Start + + +For more flexibility and control over training, TRL provides dedicated trainer classes to post-train language models or PEFT adapters on a custom dataset. Each trainer in TRL is a light wrapper around the 🤗 Transformers trainer and natively supports distributed training methods like DDP, DeepSpeed ZeRO, and FSDP. + +### `SFTTrainer` + +Here is a basic example of how to use the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer): + +```python +from trl import SFTTrainer +from datasets import load_dataset + +dataset = load_dataset("trl-lib/Capybara", split="train") + +trainer = SFTTrainer( + model="Qwen/Qwen2.5-0.5B", + train_dataset=dataset, +) +trainer.train() +``` + +### `GRPOTrainer` + +[`GRPOTrainer`](https://huggingface.co/docs/trl/grpo_trainer) implements the [Group Relative Policy Optimization (GRPO) algorithm](https://huggingface.co/papers/2402.03300) that is more memory-efficient than PPO and was used to train [Deepseek AI's R1](https://huggingface.co/deepseek-ai/DeepSeek-R1). + +```python +from datasets import load_dataset +from trl import GRPOTrainer + +dataset = load_dataset("trl-lib/tldr", split="train") + +# Dummy reward function: count the number of unique characters in the completions +def reward_num_unique_chars(completions, **kwargs): + return [len(set(c)) for c in completions] + +trainer = GRPOTrainer( + model="Qwen/Qwen2-0.5B-Instruct", + reward_funcs=reward_num_unique_chars, + train_dataset=dataset, +) +trainer.train() +``` + +### `DPOTrainer` + +[`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer) implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train [Llama 3](https://huggingface.co/papers/2407.21783) and many other models. Here is a basic example of how to use the `DPOTrainer`: + +```python +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from trl import DPOConfig, DPOTrainer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") +training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO") +trainer = DPOTrainer( + model=model, + args=training_args, + train_dataset=dataset, + processing_class=tokenizer +) +trainer.train() +``` + +### `RewardTrainer` + +Here is a basic example of how to use the [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer): + +```python +from trl import RewardConfig, RewardTrainer +from datasets import load_dataset +from transformers import AutoModelForSequenceClassification, AutoTokenizer + +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +model = AutoModelForSequenceClassification.from_pretrained( + "Qwen/Qwen2.5-0.5B-Instruct", num_labels=1 +) +model.config.pad_token_id = tokenizer.pad_token_id + +dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") + +training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2) +trainer = RewardTrainer( + args=training_args, + model=model, + processing_class=tokenizer, + train_dataset=dataset, +) +trainer.train() +``` + +## Command Line Interface (CLI) + +You can use the TRL Command Line Interface (CLI) to quickly get started with post-training methods like Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO): + +**SFT:** + +```bash +trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name trl-lib/Capybara \ + --output_dir Qwen2.5-0.5B-SFT +``` + +**DPO:** + +```bash +trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ + --dataset_name argilla/Capybara-Preferences \ + --output_dir Qwen2.5-0.5B-DPO +``` + +Read more about CLI in the [relevant documentation section](https://huggingface.co/docs/trl/main/en/clis) or use `--help` for more details. + +## Development + +If you want to contribute to `trl` or customize it to your needs make sure to read the [contribution guide](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) and make sure you make a dev install: + +```bash +git clone https://github.com/huggingface/trl.git +cd trl/ +pip install -e .[dev] +``` + +## Citation + +```bibtex +@misc{vonwerra2022trl, + author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec}, + title = {TRL: Transformer Reinforcement Learning}, + year = {2020}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/huggingface/trl}} +} +``` + +## License + +This repository's source code is available under the [Apache-2.0 License](LICENSE). diff --git a/commands/run_dpo.sh b/commands/run_dpo.sh new file mode 100644 index 0000000000000000000000000000000000000000..f34b12cbb1c79cac9323c1fe06728228d56f4db3 --- /dev/null +++ b/commands/run_dpo.sh @@ -0,0 +1,58 @@ +#!/bin/bash +# This script runs an SFT example end-to-end on a tiny model using different possible configurations +# but defaults to QLoRA + PEFT +OUTPUT_DIR="test_dpo/" +MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" +DATASET_NAME="trl-internal-testing/hh-rlhf-helpful-base-trl-style" +MAX_STEPS=5 +BATCH_SIZE=2 +SEQ_LEN=128 + +# Handle extra arguments in case one passes accelerate configs. +EXTRA_ACCELERATE_ARGS="" +EXTRA_TRAINING_ARGS="""--use_peft \ + --load_in_4bit +""" + +# This is a hack to get the number of available GPUs +NUM_GPUS=2 + +if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then + EXTRA_ACCELERATE_ARGS="" +else + EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG" + # For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed + # on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training. + if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then + EXTRA_TRAINING_ARGS="--fp16" + else + echo "Keeping QLoRA + PEFT" + fi +fi + + +CMD=""" +accelerate launch $EXTRA_ACCELERATE_ARGS \ + --num_processes $NUM_GPUS \ + --mixed_precision 'fp16' \ + `pwd`/trl/scripts/dpo.py \ + --model_name_or_path $MODEL_NAME \ + --dataset_name $DATASET_NAME \ + --output_dir $OUTPUT_DIR \ + --max_steps $MAX_STEPS \ + --per_device_train_batch_size $BATCH_SIZE \ + --max_length $SEQ_LEN \ + $EXTRA_TRAINING_ARGS +""" + +echo "Starting program..." + +{ # try + echo $CMD + eval "$CMD" +} || { # catch + # save log for exception + echo "Operation Failed!" + exit 1 +} +exit 0 diff --git a/commands/run_sft.sh b/commands/run_sft.sh new file mode 100644 index 0000000000000000000000000000000000000000..b7beaaf7fdd65bdaad239e66dbf1bf6262091392 --- /dev/null +++ b/commands/run_sft.sh @@ -0,0 +1,59 @@ +#!/bin/bash +# This script runs an SFT example end-to-end on a tiny model using different possible configurations +# but defaults to QLoRA + PEFT +OUTPUT_DIR="test_sft/" +MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" +DATASET_NAME="stanfordnlp/imdb" +MAX_STEPS=5 +BATCH_SIZE=2 +SEQ_LEN=128 + + +# Handle extra arguments in case one passes accelerate configs. +EXTRA_ACCELERATE_ARGS="" +EXTRA_TRAINING_ARGS="""--use_peft \ + --load_in_4bit +""" + +# Set your number of GPUs here +NUM_GPUS=2 + +if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then + EXTRA_ACCELERATE_ARGS="" +else + EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG" + # For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed + # on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training. + if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then + EXTRA_TRAINING_ARGS="--fp16" + else + echo "Keeping QLoRA + PEFT" + fi +fi + + +CMD=""" +accelerate launch $EXTRA_ACCELERATE_ARGS \ + --num_processes $NUM_GPUS \ + --mixed_precision 'fp16' \ + `pwd`/trl/scripts/sft.py \ + --model_name $MODEL_NAME \ + --dataset_name $DATASET_NAME \ + --output_dir $OUTPUT_DIR \ + --max_steps $MAX_STEPS \ + --per_device_train_batch_size $BATCH_SIZE \ + --max_length $SEQ_LEN \ + $EXTRA_TRAINING_ARGS +""" + +echo "Starting program..." + +{ # try + echo $CMD + eval "$CMD" +} || { # catch + # save log for exception + echo "Operation Failed!" + exit 1 +} +exit 0 diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000000000000000000000000000000000000..3ae2e502baf7b0c706e290e4d46d9c952139da60 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,5 @@ +services: + workspace: + build: + context: . + dockerfile: Dockerfile diff --git a/docker/trl-latest-gpu/Dockerfile b/docker/trl-latest-gpu/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..6c53033961932a36bc551f94e7307079903fb944 --- /dev/null +++ b/docker/trl-latest-gpu/Dockerfile @@ -0,0 +1,66 @@ +# Builds GPU docker image of PyTorch +# Uses multi-staged approach to reduce size +# Stage 1 +# Use base conda image to reduce time +FROM continuumio/miniconda3:latest AS compile-image +# Specify py version +ENV PYTHON_VERSION=3.10 +# Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile +RUN apt-get update && \ + apt-get install -y curl git wget software-properties-common git-lfs && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists* + +# Install audio-related libraries +RUN apt-get update && \ + apt install -y ffmpeg + +RUN apt install -y libsndfile1-dev +RUN git lfs install + +# Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile +RUN conda create --name trl python=${PYTHON_VERSION} ipython jupyter pip +RUN python3 -m pip install --no-cache-dir --upgrade pip + +# Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile +# We don't install pytorch here yet since CUDA isn't available +# instead we use the direct torch wheel +ENV PATH /opt/conda/envs/trl/bin:$PATH +# Activate our bash shell +RUN chsh -s /bin/bash +SHELL ["/bin/bash", "-c"] + +# Stage 2 +FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS build-image +COPY --from=compile-image /opt/conda /opt/conda +ENV PATH /opt/conda/bin:$PATH + +RUN chsh -s /bin/bash +SHELL ["/bin/bash", "-c"] +RUN source activate trl && \ + python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq + +# Install apt libs +RUN apt-get update && \ + apt-get install -y curl git wget && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists* + +# Activate the conda env and install transformers + accelerate from source +RUN source activate trl && \ + python3 -m pip install -U --no-cache-dir \ + librosa \ + "soundfile>=0.12.1" \ + scipy \ + transformers \ + accelerate \ + peft \ + trl[test]@git+https://github.com/huggingface/trl + +RUN source activate trl && \ + pip freeze | grep trl + +RUN echo "source activate trl" >> ~/.profile + +# Activate the virtualenv +CMD ["/bin/bash"] \ No newline at end of file diff --git a/docker/trl-source-gpu/Dockerfile b/docker/trl-source-gpu/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..21e72504373c5ada24d2eb608eaf32042faae06d --- /dev/null +++ b/docker/trl-source-gpu/Dockerfile @@ -0,0 +1,66 @@ +# Builds GPU docker image of PyTorch +# Uses multi-staged approach to reduce size +# Stage 1 +# Use base conda image to reduce time +FROM continuumio/miniconda3:latest AS compile-image +# Specify py version +ENV PYTHON_VERSION=3.10 +# Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile +RUN apt-get update && \ + apt-get install -y curl git wget software-properties-common git-lfs && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists* + +# Install audio-related libraries +RUN apt-get update && \ + apt install -y ffmpeg + +RUN apt install -y libsndfile1-dev +RUN git lfs install + +# Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile +RUN conda create --name trl python=${PYTHON_VERSION} ipython jupyter pip +RUN python3 -m pip install --no-cache-dir --upgrade pip + +# Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile +# We don't install pytorch here yet since CUDA isn't available +# instead we use the direct torch wheel +ENV PATH /opt/conda/envs/trl/bin:$PATH +# Activate our bash shell +RUN chsh -s /bin/bash +SHELL ["/bin/bash", "-c"] + +# Stage 2 +FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS build-image +COPY --from=compile-image /opt/conda /opt/conda +ENV PATH /opt/conda/bin:$PATH + +RUN chsh -s /bin/bash +SHELL ["/bin/bash", "-c"] +RUN source activate trl && \ + python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq + +# Install apt libs +RUN apt-get update && \ + apt-get install -y curl git wget && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists* + +# Activate the conda env and install transformers + accelerate from source +RUN source activate trl && \ + python3 -m pip install -U --no-cache-dir \ + librosa \ + "soundfile>=0.12.1" \ + scipy \ + git+https://github.com/huggingface/transformers \ + git+https://github.com/huggingface/accelerate \ + git+https://github.com/huggingface/peft \ + trl[test]@git+https://github.com/huggingface/trl + +RUN source activate trl && \ + pip freeze | grep transformers + +RUN echo "source activate trl" >> ~/.profile + +# Activate the virtualenv +CMD ["/bin/bash"] \ No newline at end of file diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml new file mode 100644 index 0000000000000000000000000000000000000000..6d4a0cac292c91a9ebca502648b74a20a3b64b0b --- /dev/null +++ b/docs/source/_toctree.yml @@ -0,0 +1,116 @@ +- sections: + - local: index + title: TRL + - local: installation + title: Installation + - local: quickstart + title: Quickstart + title: Getting started +- sections: + - local: dataset_formats + title: Dataset Formats + - local: how_to_train + title: Training FAQ + - local: logging + title: Understanding Logs + title: Conceptual Guides +- sections: + - local: clis + title: Command Line Interface (CLI) + - local: customization + title: Customizing the Training + - local: reducing_memory_usage + title: Reducing Memory Usage + - local: speeding_up_training + title: Speeding Up Training + - local: distributing_training + title: Distributing Training + - local: use_model + title: Using Trained Models + title: How-to guides +- sections: + - local: deepspeed_integration + title: DeepSpeed + - local: liger_kernel_integration + title: Liger Kernel + - local: peft_integration + title: PEFT + - local: unsloth_integration + title: Unsloth + - local: vllm_integration + title: vLLM + title: Integrations +- sections: + - local: example_overview + title: Example Overview + - local: community_tutorials + title: Community Tutorials + - local: sentiment_tuning + title: Sentiment Tuning + - local: using_llama_models + title: Training StackLlama + - local: detoxifying_a_lm + title: Detoxifying a Language Model + - local: multi_adapter_rl + title: Multi Adapter RLHF + - local: training_vlm_sft + title: Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset) + title: Examples +- sections: + - sections: # Sorted alphabetically + - local: alignprop_trainer + title: AlignProp + - local: bco_trainer + title: BCO + - local: cpo_trainer + title: CPO + - local: ddpo_trainer + title: DDPO + - local: dpo_trainer + title: DPO + - local: online_dpo_trainer + title: Online DPO + - local: gkd_trainer + title: GKD + - local: grpo_trainer + title: GRPO + - local: kto_trainer + title: KTO + - local: nash_md_trainer + title: Nash-MD + - local: orpo_trainer + title: ORPO + - local: ppo_trainer + title: PPO + - local: prm_trainer + title: PRM + - local: reward_trainer + title: Reward + - local: rloo_trainer + title: RLOO + - local: sft_trainer + title: SFT + - local: iterative_sft_trainer + title: Iterative SFT + - local: xpo_trainer + title: XPO + title: Trainers + - local: models + title: Model Classes + - local: model_utils + title: Model Utilities + - local: best_of_n + title: Best of N Sampling + - local: judges + title: Judges + - local: callbacks + title: Callbacks + - local: data_utils + title: Data Utilities + - local: rewards + title: Reward Functions + - local: script_utils + title: Script Utilities + - local: others + title: Others + title: API diff --git a/docs/source/alignprop_trainer.md b/docs/source/alignprop_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..4c3b21042c61128ed8031d4a6c56b6407539e004 --- /dev/null +++ b/docs/source/alignprop_trainer.md @@ -0,0 +1,93 @@ +# Aligning Text-to-Image Diffusion Models with Reward Backpropagation + +[![](https://img.shields.io/badge/All_models-AlignProp-blue)](https://huggingface.co/models?other=alignprop,trl) + +## The why + +If your reward function is differentiable, directly backpropagating gradients from the reward models to the diffusion model is significantly more sample and compute efficient (25x) than doing policy gradient algorithm like DDPO. +AlignProp does full backpropagation through time, which allows updating the earlier steps of denoising via reward backpropagation. + +
+ + +## Getting started with `examples/scripts/alignprop.py` + +The `alignprop.py` script is a working example of using the `AlignProp` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`AlignPropConfig`). + +**Note:** one A100 GPU is recommended to get this running. For lower memory setting, consider setting truncated_backprop_rand to False. With default settings this will do truncated backpropagation with K=1. + +Almost every configuration parameter has a default. There is only one commandline flag argument that is required of the user to get things up and running. The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model post-finetuning to HuggingFace hub. The following bash command is to be entered to get things running + +```batch +python alignprop.py --hf_user_access_token +``` + +To obtain the documentation of `stable_diffusion_tuning.py`, please run `python stable_diffusion_tuning.py --help` + +The following are things to keep in mind (The code checks this for you as well) in general while configuring the trainer (beyond the use case of using the example script) + +- The configurable randomized truncation range (`--alignprop_config.truncated_rand_backprop_minmax=(0,50)`) the first number should be equal and greater than 0, while the second number should equal or less to the number of diffusion timesteps (sample_num_steps) +- The configurable truncation backprop absolute step (`--alignprop_config.truncated_backprop_timestep=49`) the number should be less than the number of diffusion timesteps (sample_num_steps), it only matters when truncated_backprop_rand is set to False + +## Setting up the image logging hook function + +Expect the function to be given a dictionary with keys +```python +['image', 'prompt', 'prompt_metadata', 'rewards'] + +``` +and `image`, `prompt`, `prompt_metadata`, `rewards`are batched. +You are free to log however you want the use of `wandb` or `tensorboard` is recommended. + +### Key terms + +- `rewards` : The rewards/score is a numerical associated with the generated image and is key to steering the RL process +- `prompt` : The prompt is the text that is used to generate the image +- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises of a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup where questions and ground answers (linked to the generated image) are expected with the generated image (See here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45) +- `image` : The image generated by the Stable Diffusion model + +Example code for logging sampled images with `wandb` is given below. + +```python +# for logging these images to wandb + +def image_outputs_hook(image_data, global_step, accelerate_logger): + # For the sake of this example, we only care about the last batch + # hence we extract the last element of the list + result = {} + images, prompts, rewards = [image_data['images'],image_data['prompts'],image_data['rewards']] + for i, image in enumerate(images): + pil = Image.fromarray( + (image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8) + ) + pil = pil.resize((256, 256)) + result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil] + accelerate_logger.log_images( + result, + step=global_step, + ) + +``` + +### Using the finetuned model + +Assuming you've done with all the epochs and have pushed up your model to the hub, you can use the finetuned model as follows + +```python +from diffusers import StableDiffusionPipeline +pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") +pipeline.to("cuda") + +pipeline.load_lora_weights('mihirpd/alignprop-trl-aesthetics') + +prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"] +results = pipeline(prompts) + +for prompt, image in zip(prompts,results.images): + image.save(f"dump/{prompt}.png") +``` + +## Credits + +This work is heavily influenced by the repo [here](https://github.com/mihirp1998/AlignProp/) and the associated paper [Aligning Text-to-Image Diffusion Models with Reward Backpropagation + by Mihir Prabhudesai, Anirudh Goyal, Deepak Pathak, Katerina Fragkiadaki](https://huggingface.co/papers/2310.03739). diff --git a/docs/source/bco_trainer.md b/docs/source/bco_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..e449f86b63d95c288af3359d80413f56fc5b7c72 --- /dev/null +++ b/docs/source/bco_trainer.md @@ -0,0 +1,100 @@ +# BCO Trainer + +[![](https://img.shields.io/badge/All_models-BCO-blue)](https://huggingface.co/models?other=bco,trl) + +TRL supports the Binary Classifier Optimization (BCO). +The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. +For a full example have a look at [`examples/scripts/bco.py`]. + +## Expected dataset type + +The [`BCOTrainer`] requires an [unpaired preference dataset](dataset_formats#unpaired-preference). +The [`BCOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. + +## Expected model format +The BCO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function. + +## Using the `BCOTrainer` + +For a detailed example have a look at the `examples/scripts/bco.py` script. At a high level we need to initialize the `BCOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response. + +The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder). + + + +```py +training_args = BCOConfig( + beta=0.1, +) + +bco_trainer = BCOTrainer( + model, + model_ref, + args=training_args, + train_dataset=train_dataset, + processing_class=tokenizer, +) +``` +After this one can then call: + +```py +bco_trainer.train() +``` + +## Underlying Distribution matching (UDM) + +In practical scenarios, the thumbs-up and thumbs-down datasets are likely to have divergent underlying distributions of prompts. +Consider an LLM deployed for user feedback: if the model excels in writing tasks but underperforms in coding, the thumbs-up dataset will be dominated by writing-related prompts, while the thumbs-down dataset will contain mostly coding-related prompts. +If the prompts in your desired and undesired datasets differ a lot, it is useful to enable UDM. + +Choose an embedding model and tokenizer: + +```py +embedding_model = AutoModel.from_pretrained(your_model_id) +embedding_tokenizer = AutoTokenizer.from_pretrained(your_model_id) + +# customize this function depending on your embedding model +def embed_prompt(input_ids, attention_mask, model): + outputs = model(input_ids=input_ids, attention_mask=attention_mask) + return outputs.last_hidden_state.mean(dim=1) + +embedding_model = Accelerator().prepare_model(self.embedding_model) +embedding_func = partial(embed_prompt, model=embedding_model) +``` + +Set `prompt_sample_size` to define how many prompts are selected to train the UDM classifier and start the training with the provided embedding function: + +```py +training_args = BCOConfig( + beta=0.1, + prompt_sample_size=512, +) + +bco_trainer = BCOTrainer( + model, + model_ref, + args=training_args, + train_dataset=train_dataset, + processing_class=tokenizer, + embedding_func=embedding_func, + embedding_tokenizer=self.embedding_tokenizer, +) + +bco_trainer.train() +``` + +### For Mixture of Experts Models: Enabling the auxiliary loss + +MOEs are the most efficient if the load is about equally distributed between experts. +To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss. + +This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig). +To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001). + +## BCOTrainer + +[[autodoc]] BCOTrainer + +## BCOConfig + +[[autodoc]] BCOConfig diff --git a/docs/source/best_of_n.md b/docs/source/best_of_n.md new file mode 100644 index 0000000000000000000000000000000000000000..8b2978c2a38db8dba30c08956988e19ac95e2cd2 --- /dev/null +++ b/docs/source/best_of_n.md @@ -0,0 +1,72 @@ +# Best of N sampling: Alternative ways to get better model output without RL based fine-tuning + +Within the extras module is the `best-of-n` sampler class that serves as an alternative method of generating better model output. +As to how it fares against the RL based fine-tuning, please look in the `examples` directory for a comparison example + +## Usage + +To get started quickly, instantiate an instance of the class with a model, a length sampler, a tokenizer and a callable that serves as a proxy reward pipeline that outputs reward scores for input queries + +```python + +from transformers import pipeline, AutoTokenizer +from trl import AutoModelForCausalLMWithValueHead +from trl.core import LengthSampler +from trl.extras import BestOfNSampler + +ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name) +reward_pipe = pipeline("sentiment-analysis", model=reward_model, device=device) +tokenizer = AutoTokenizer.from_pretrained(ref_model_name) +tokenizer.pad_token = tokenizer.eos_token + + +# callable that takes a list of raw text and returns a list of corresponding reward scores +def queries_to_scores(list_of_strings): + return [output["score"] for output in reward_pipe(list_of_strings)] + +best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler) + + +``` + +And assuming you have a list/tensor of tokenized queries, you can generate better output by calling the `generate` method + +```python + +best_of_n.generate(query_tensors, device=device, **gen_kwargs) + +``` +The default sample size is 4, but you can change it at the time of instance initialization like so + +```python + +best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, sample_size=8) + +``` + +The default output is the result of taking the top scored output for each query, but you can change it to top 2 and so on by passing the `n_candidates` argument at the time of instance initialization + +```python + +best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, n_candidates=2) + +``` + +There is the option of setting the generation settings (like `temperature`, `pad_token_id`) at the time of instance creation as opposed to when calling the `generate` method. +This is done by passing a `GenerationConfig` from the `transformers` library at the time of initialization + +```python + +from transformers import GenerationConfig + +generation_config = GenerationConfig(min_length= -1, top_k=0.0, top_p= 1.0, do_sample= True, pad_token_id=tokenizer.eos_token_id) + +best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, generation_config=generation_config) + +best_of_n.generate(query_tensors, device=device) + +``` + +Furthermore, at the time of initialization you can set the seed to control the repeatability of the generation process and the number of samples to generate for each query + + diff --git a/docs/source/callbacks.md b/docs/source/callbacks.md new file mode 100644 index 0000000000000000000000000000000000000000..7959a26b9de0982e071558bd01f96bf1290841b9 --- /dev/null +++ b/docs/source/callbacks.md @@ -0,0 +1,21 @@ +# Callbacks + +## SyncRefModelCallback + +[[autodoc]] SyncRefModelCallback + +## RichProgressCallback + +[[autodoc]] RichProgressCallback + +## WinRateCallback + +[[autodoc]] WinRateCallback + +## LogCompletionsCallback + +[[autodoc]] LogCompletionsCallback + +## MergeModelCallback + +[[autodoc]] MergeModelCallback \ No newline at end of file diff --git a/docs/source/clis.md b/docs/source/clis.md new file mode 100644 index 0000000000000000000000000000000000000000..0938dec2620b345e2c6f6d2ca975f88f7846e264 --- /dev/null +++ b/docs/source/clis.md @@ -0,0 +1,272 @@ +# Command Line Interfaces (CLIs) + +TRL provides a powerful command-line interface (CLI) to fine-tune large language models (LLMs) using methods like Supervised Fine-Tuning (SFT), Direct Preference Optimization (DPO), and more. The CLI abstracts away much of the boilerplate, letting you launch training jobs quickly and reproducibly. + +Currently supported commands are: + +#### Training Commands + +- `trl dpo`: fine-tune a LLM with DPO +- `trl grpo`: fine-tune a LLM with GRPO +- `trl kto`: fine-tune a LLM with KTO +- `trl sft`: fine-tune a LLM with SFT + +#### Other Commands + +- `trl env`: get the system information +- `trl vllm-serve`: serve a model with vLLM + +## Fine-Tuning with the TRL CLI + +### Basic Usage + +You can launch training directly from the CLI by specifying required arguments like the model and dataset: + + + + +```bash +trl sft \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name stanfordnlp/imdb +``` + + + + +```bash +trl dpo \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name anthropic/hh-rlhf +``` + + + + +### Using Configuration Files + +To keep your CLI commands clean and reproducible, you can define all training arguments in a YAML configuration file: + + + + +```yaml +# sft_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: stanfordnlp/imdb +``` + +Launch with: + +```bash +trl sft --config sft_config.yaml +``` + + + + +```yaml +# dpo_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: anthropic/hh-rlhf +``` + +Launch with: + +```bash +trl dpo --config dpo_config.yaml +``` + + + + +### Scaling Up with Accelerate + +TRL CLI natively supports [🤗 Accelerate](https://huggingface.co/docs/accelerate), making it easy to scale training across multiple GPUs, machines, or use advanced setups like DeepSpeed — all from the same CLI. + +You can pass any `accelerate launch` arguments directly to `trl`, such as `--num_processes`. For more information see [Using accelerate launch](https://huggingface.co/docs/accelerate/en/basic_tutorials/launch#using-accelerate-launch). + + + + +```bash +trl sft \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name stanfordnlp/imdb \ + --num_processes 4 +``` + + + + +```yaml +# sft_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: stanfordnlp/imdb +num_processes: 4 +``` + +Launch with: + +```bash +trl sft --config sft_config.yaml +``` + + + + +```bash +trl dpo \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name anthropic/hh-rlhf \ + --num_processes 4 +``` + + + + +```yaml +# dpo_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: anthropic/hh-rlhf +num_processes: 4 +``` + +Launch with: + +```bash +trl dpo --config dpo_config.yaml +``` + + + +### Using `--accelerate_config` for Accelerate Configuration + +The `--accelerate_config` flag lets you easily configure distributed training with [🤗 Accelerate](https://github.com/huggingface/accelerate). This flag accepts either: + +* the name of a predefined config profile (built into TRL), or +* a path to a custom Accelerate YAML config file. + +#### Predefined Config Profiles + +TRL provides several ready-to-use Accelerate configs to simplify common training setups: + +| Name | Description | +| ------------ | ----------------------------------- | +| `fsdp1` | Fully Sharded Data Parallel Stage 1 | +| `fsdp2` | Fully Sharded Data Parallel Stage 2 | +| `zero1` | DeepSpeed ZeRO Stage 1 | +| `zero2` | DeepSpeed ZeRO Stage 2 | +| `zero3` | DeepSpeed ZeRO Stage 3 | +| `multi_gpu` | Multi-GPU training | +| `single_gpu` | Single-GPU training | + +To use one of these, just pass the name to `--accelerate_config`. TRL will automatically load the corresponding config file from `trl/accelerate_config/`. + +#### Example Usage + + + + +```bash +trl sft \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name stanfordnlp/imdb \ + --accelerate_config zero2 # or path/to/my/accelerate/config.yaml +``` + + + + +```yaml +# sft_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: stanfordnlp/imdb +accelerate_config: zero2 # or path/to/my/accelerate/config.yaml +``` + +Launch with: + +```bash +trl sft --config sft_config.yaml +``` + + + + +```bash +trl dpo \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name anthropic/hh-rlhf \ + --accelerate_config zero2 # or path/to/my/accelerate/config.yaml +``` + + + + +```yaml +# dpo_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: anthropic/hh-rlhf +accelerate_config: zero2 # or path/to/my/accelerate/config.yaml +``` + +Launch with: + +```bash +trl dpo --config dpo_config.yaml +``` + + + +## Getting the System Information + +You can get the system information by running the following command: + +```bash +trl env +``` + +This will print out the system information, including the GPU information, the CUDA version, the PyTorch version, the transformers version, the TRL version, and any optional dependencies that are installed. + +```txt +Copy-paste the following information when reporting an issue: + +- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31 +- Python version: 3.11.9 +- PyTorch version: 2.4.1 +- accelerator(s): NVIDIA H100 80GB HBM3 +- Transformers version: 4.45.0.dev0 +- Accelerate version: 0.34.2 +- Accelerate config: + - compute_environment: LOCAL_MACHINE + - distributed_type: DEEPSPEED + - mixed_precision: no + - use_cpu: False + - debug: False + - num_processes: 4 + - machine_rank: 0 + - num_machines: 1 + - rdzv_backend: static + - same_network: True + - main_training_function: main + - enable_cpu_affinity: False + - deepspeed_config: {'gradient_accumulation_steps': 4, 'offload_optimizer_device': 'none', 'offload_param_device': 'none', 'zero3_init_flag': False, 'zero_stage': 2} + - downcast_bf16: no + - tpu_use_cluster: False + - tpu_use_sudo: False + - tpu_env: [] +- Datasets version: 3.0.0 +- HF Hub version: 0.24.7 +- TRL version: 0.12.0.dev0+acb4d70 +- bitsandbytes version: 0.41.1 +- DeepSpeed version: 0.15.1 +- Diffusers version: 0.30.3 +- Liger-Kernel version: 0.3.0 +- LLM-Blender version: 0.0.2 +- OpenAI version: 1.46.0 +- PEFT version: 0.12.0 +- vLLM version: not installed +``` + +This information is required when reporting an issue. diff --git a/docs/source/community_tutorials.md b/docs/source/community_tutorials.md new file mode 100644 index 0000000000000000000000000000000000000000..f19742cb6a2c633f8be3a93a4537b84d2f934ccd --- /dev/null +++ b/docs/source/community_tutorials.md @@ -0,0 +1,32 @@ +# Community Tutorials + +Community tutorials are made by active members of the Hugging Face community who want to share their knowledge and expertise with others. They are a great way to learn about the library and its features, and to get started with core classes and modalities. + +# Language Models + +| Task | Class | Description | Author | Tutorial | Colab | +| --- | --- | --- | --- | --- | --- | +| Reinforcement Learning | [`GRPOTrainer`] | Post training an LLM for reasoning with GRPO in TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_llm_grpo_trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_grpo_trl.ipynb) | +| Reinforcement Learning | [`GRPOTrainer`] | Mini-R1: Reproduce Deepseek R1 „aha moment“ a RL tutorial | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/mini-deepseek-r1) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/mini-deepseek-r1-aha-grpo.ipynb) | +| Reinforcement Learning | [`GRPOTrainer`] | RL on LLaMA 3.1-8B with GRPO and Unsloth optimizations | [Andrea Manzoni](https://huggingface.co/AManzoni) | [Link](https://colab.research.google.com/github/amanzoni1/fine_tuning/blob/main/RL_LLama3_1_8B_GRPO.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/amanzoni1/fine_tuning/blob/main/RL_LLama3_1_8B_GRPO.ipynb) | +| Instruction tuning | [`SFTTrainer`] | Fine-tuning Google Gemma LLMs using ChatML format with QLoRA | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-google-gemma) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/gemma-lora-example.ipynb) | +| Structured Generation | [`SFTTrainer`] | Fine-tuning Llama-2-7B to generate Persian product catalogs in JSON using QLoRA and PEFT | [Mohammadreza Esmaeilian](https://huggingface.co/Mohammadreza) | [Link](https://huggingface.co/learn/cookbook/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format.ipynb) | +| Preference Optimization | [`DPOTrainer`] | Align Mistral-7b using Direct Preference Optimization for human preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/Fine_tune_Mistral_7b_with_DPO.html) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mlabonne/llm-course/blob/main/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb) | +| Preference Optimization | [`ORPOTrainer`] | Fine-tuning Llama 3 with ORPO combining instruction tuning and preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/2024-04-19_Fine_tune_Llama_3_with_ORPO.html) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eHNWg9gnaXErdAa8_mcvjMupbSS6rDvi) | +| Instruction tuning | [`SFTTrainer`] | How to fine-tune open LLMs in 2025 with Hugging Face | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-llms-in-2025) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-llms-in-2025.ipynb) | + + + +# Vision Language Models + +| Task | Class | Description | Author | Tutorial | Colab | +| --- | --- | --- | --- | --- | --- | +| Visual QA | [`SFTTrainer`] | Fine-tuning Qwen2-VL-7B for visual question answering on ChartQA dataset | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_trl.ipynb) | +| Visual QA | [`SFTTrainer`] | Fine-tuning SmolVLM with TRL on a consumer GPU | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_smol_vlm_sft_trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_smol_vlm_sft_trl.ipynb) | +| SEO Description | [`SFTTrainer`] | Fine-tuning Qwen2-VL-7B for generating SEO-friendly descriptions from images | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-multimodal-llms-with-trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-multimodal-llms-with-trl.ipynb) | +| Visual QA | [`DPOTrainer`] | PaliGemma 🤝 Direct Preference Optimization | [Merve Noyan](https://huggingface.co/merve) | [Link](https://github.com/merveenoyan/smol-vision/blob/main/PaliGemma_DPO.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/merveenoyan/smol-vision/blob/main/PaliGemma_DPO.ipynb) | +| Visual QA | [`DPOTrainer`] | Fine-tuning SmolVLM using direct preference optimization (DPO) with TRL on a consumer GPU | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_dpo_smolvlm_instruct) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_dpo_smolvlm_instruct.ipynb) | + +## Contributing + +If you have a tutorial that you would like to add to this list, please open a PR to add it. We will review it and merge it if it is relevant to the community. diff --git a/docs/source/cpo_trainer.md b/docs/source/cpo_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..24e0f3fdaee4bf98bdfad8b1bf70e9a741a2e9af --- /dev/null +++ b/docs/source/cpo_trainer.md @@ -0,0 +1,108 @@ +# CPO Trainer + +[![](https://img.shields.io/badge/All_models-CPO-blue)](https://huggingface.co/models?other=cpo,trl) + +## Overview + +Contrastive Preference Optimization (CPO) as introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by [Haoran Xu](https://huggingface.co/haoranxu), [Amr Sharaf](https://huggingface.co/amrsharaf), [Yunmo Chen](https://huggingface.co/yunmochen), Weiting Tan, Lingfeng Shen, Benjamin Van Durme, [Kenton Murray](https://huggingface.co/Kenton), and [Young Jin Kim](https://huggingface.co/ykim362). At a high-level, CPO trains models to avoid generating adequate, but not perfect translations in Machine Translation (MT) tasks. However, CPO is a general approximation of the DPO loss and can be applied to other domains, such as chat. + +CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT’s methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Secondly, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective. + +## Quick start + +This example demonstrates how to train a model using the CPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here: + + + +Below is the script to train the model: + +```python +# train_cpo.py +from datasets import load_dataset +from trl import CPOConfig, CPOTrainer +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") + +training_args = CPOConfig(output_dir="Qwen2-0.5B-CPO", logging_steps=10) +trainer = CPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset) +trainer.train() +``` + +Execute the script using the following command: + +```bash +accelerate launch train_cpo.py +``` + +## Expected dataset type + +CPO requires a [preference dataset](dataset_formats#preference). The [`CPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. + +## Example script + +We provide an example script to train a model using the CPO method. The script is available in [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) + +To test the CPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command: + +```bash +accelerate launch examples/scripts/cpo.py \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --num_train_epochs 1 \ + --logging_steps 25 \ + --output_dir Qwen2-0.5B-CPO +``` + +## Logged metrics + +While training and evaluating we record the following reward metrics: + +* `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta +* `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta +* `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards +* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards +* `nll_loss`: the mean negative log likelihood loss of the policy model for the chosen responses + +## CPO variants + +### Simple Preference Optimization (SimPO) + +The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, we can use SimPO easily by turning on `loss_type="simpo"` and `cpo_alpha=0.0` in the [`CPOConfig`]. + +### CPO-SimPO + +We also offer the combined use of CPO and SimPO, which enables more stable training and improved performance. Learn more details at [CPO-SimPO GitHub](https://github.com/fe1ixxu/CPO_SIMPO). To use this method, simply enable SimPO by setting `loss_type="simpo"` and a non-zero `cpo_alpha` in the [`CPOConfig`]. + +## Loss functions + +The CPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`CPOConfig`]. The following loss functions are supported: + +| `loss_type=` | Description | +| -------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model and in fact the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. | +| `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. | +| `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike DPO which is summed only). | + +### For Mixture of Experts Models: Enabling the auxiliary loss + +MOEs are the most efficient if the load is about equally distributed between experts. +To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss. + +This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]). +To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config. + +## CPOTrainer + +[[autodoc]] CPOTrainer + +## CPOConfig + +[[autodoc]] CPOConfig diff --git a/docs/source/customization.md b/docs/source/customization.md new file mode 100644 index 0000000000000000000000000000000000000000..1cfcdc27b5df096e2b84530da608ffc50e8f803c --- /dev/null +++ b/docs/source/customization.md @@ -0,0 +1,121 @@ +# Training customization + +TRL is designed with modularity in mind so that users to be able to efficiently customize the training loop for their needs. Below are some examples on how you can apply and test different techniques. Note: Although these examples use the DPOTrainer, the customization applies to most (if not all) trainers. + + + +## Use different optimizers and schedulers + +By default, the `DPOTrainer` creates a `torch.optim.AdamW` optimizer. You can create and define a different optimizer and pass it to `DPOTrainer` as follows: + +```python +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from torch import optim +from trl import DPOConfig, DPOTrainer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") +training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO") + +optimizer = optim.SGD(model.parameters(), lr=training_args.learning_rate) + +trainer = DPOTrainer( + model=model, + args=training_args, + train_dataset=dataset, + tokenizer=tokenizer, + optimizers=(optimizer, None), +) +trainer.train() +``` + +### Add a learning rate scheduler + +You can also play with your training by adding learning rate schedulers. + +```python +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from torch import optim +from trl import DPOConfig, DPOTrainer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") +training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO") + +optimizer = optim.AdamW(model.parameters(), lr=training_args.learning_rate) +lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) + +trainer = DPOTrainer( + model=model, + args=training_args, + train_dataset=dataset, + tokenizer=tokenizer, + optimizers=(optimizer, lr_scheduler), +) +trainer.train() +``` + +## Memory efficient fine-tuning by sharing layers + +Another tool you can use for more memory efficient fine-tuning is to share layers between the reference model and the model you want to train. + +```python +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from trl import create_reference_model, DPOConfig, DPOTrainer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +ref_model = create_reference_model(model, num_shared_layers=6) +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train[:1%]") +training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO") + +trainer = DPOTrainer( + model=model, + ref_model=ref_model, + args=training_args, + train_dataset=dataset, + tokenizer=tokenizer, +) +trainer.train() +``` + +## Pass 8-bit reference models + +Since `trl` supports all keyword arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage `load_in_8bit` from `transformers` for more memory efficient fine-tuning. + +Read more about 8-bit model loading in `transformers` [here](https://huggingface.co/docs/transformers/en/peft#load-in-8bit-or-4bit). + +```python +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig +from trl import DPOConfig, DPOTrainer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +quantization_config = BitsAndBytesConfig(load_in_8bit=True) +ref_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", quantization_config= quantization_config) +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") +dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") +training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO") + +trainer = DPOTrainer( + model=model, + ref_model=ref_model, + args=training_args, + train_dataset=dataset, + tokenizer=tokenizer, +) +trainer.train() +``` + +## Use the accelerator cache optimizer + +When training large models, you should better handle the accelerator cache by iteratively clearing it. To do so, simply pass `optimize_device_cache=True` to `DPOConfig`: + +```python +training_args = DPOConfig(..., optimize_device_cache=True) +``` diff --git a/docs/source/data_utils.md b/docs/source/data_utils.md new file mode 100644 index 0000000000000000000000000000000000000000..e4acfbb41b80304aaea4c7255fb58111e94f31bb --- /dev/null +++ b/docs/source/data_utils.md @@ -0,0 +1,41 @@ +# Data Utilities + +## is_conversational + +[[autodoc]] is_conversational + +## apply_chat_template + +[[autodoc]] apply_chat_template + +## maybe_apply_chat_template + +[[autodoc]] maybe_apply_chat_template + +## maybe_convert_to_chatml + +[[autodoc]] maybe_convert_to_chatml + +## extract_prompt + +[[autodoc]] extract_prompt + +## maybe_extract_prompt + +[[autodoc]] maybe_extract_prompt + +## unpair_preference_dataset + +[[autodoc]] unpair_preference_dataset + +## maybe_unpair_preference_dataset + +[[autodoc]] maybe_unpair_preference_dataset + +## pack_dataset + +[[autodoc]] pack_dataset + +## truncate_dataset + +[[autodoc]] truncate_dataset diff --git a/docs/source/dataset_formats.md b/docs/source/dataset_formats.md new file mode 100644 index 0000000000000000000000000000000000000000..f694629875275ec9eb84b9f09cdc83b7d638224c --- /dev/null +++ b/docs/source/dataset_formats.md @@ -0,0 +1,938 @@ +# Dataset formats and types + +This guide provides an overview of the dataset formats and types supported by each trainer in TRL. + +## Overview of the dataset formats and types + +- The *format* of a dataset refers to how the data is structured, typically categorized as either *standard* or *conversational*. +- The *type* is associated with the specific task the dataset is designed for, such as *prompt-only* or *preference*. Each type is characterized by its columns, which vary according to the task, as shown in the table. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Type \ FormatStandardConversational
Language modeling +
{"text": "The sky is blue."}
+
+
{"messages": [{"role": "user", "content": "What color is the sky?"},
+              {"role": "assistant", "content": "It is blue."}]}
+
Prompt-only +
{"prompt": "The sky is"}
+
+
{"prompt": [{"role": "user", "content": "What color is the sky?"}]}
+
Prompt-completion +
{"prompt": "The sky is",
+ "completion": " blue."}
+
+
{"prompt": [{"role": "user", "content": "What color is the sky?"}],
+ "completion": [{"role": "assistant", "content": "It is blue."}]}
+
Preference +
{"prompt": "The sky is",
+ "chosen": " blue.",
+ "rejected": " green."}
+ or, with implicit prompt: +
{"chosen": "The sky is blue.",
+ "rejected": "The sky is green."}
+
+
{"prompt": [{"role": "user", "content": "What color is the sky?"}],
+ "chosen": [{"role": "assistant", "content": "It is blue."}],
+ "rejected": [{"role": "assistant", "content": "It is green."}]}
+ or, with implicit prompt: +
{"chosen": [{"role": "user", "content": "What color is the sky?"},
+              {"role": "assistant", "content": "It is blue."}],
+ "rejected": [{"role": "user", "content": "What color is the sky?"},
+                {"role": "assistant", "content": "It is green."}]}
+
Unpaired preference +
{"prompt": "The sky is",
+ "completion": " blue.",
+ "label": True}
+
+
{"prompt": [{"role": "user", "content": "What color is the sky?"}],
+ "completion": [{"role": "assistant", "content": "It is green."}],
+ "label": False}
+
Stepwise supervision +
{"prompt": "Which number is larger, 9.8 or 9.11?",
+ "completions": ["The fractional part of 9.8 is 0.8.", 
+                 "The fractional part of 9.11 is 0.11.",
+                 "0.11 is greater than 0.8.",
+                 "Hence, 9.11 > 9.8."],
+ "labels": [True, True, False, False]}
+
+ +### Formats + +#### Standard + +The standard dataset format typically consists of plain text strings. The columns in the dataset vary depending on the task. This is the format expected by TRL trainers. Below are examples of standard dataset formats for different tasks: + +```python +# Language modeling +language_modeling_example = {"text": "The sky is blue."} +# Preference +preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."} +# Unpaired preference +unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "label": True} +``` + +#### Conversational + +Conversational datasets are used for tasks involving dialogues or chat interactions between users and assistants. Unlike standard dataset formats, these contain sequences of messages where each message has a `role` (e.g., `"user"` or `"assistant"`) and `content` (the message text). + +```python +messages = [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing great. How can I help you today?"}, + {"role": "user", "content": "I'd like to show off how chat templating works!"}, +] +``` + +Just like standard datasets, the columns in conversational datasets vary depending on the task. Below are examples of conversational dataset formats for different tasks: + +```python +# Prompt-completion +prompt_completion_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}], + "completion": [{"role": "assistant", "content": "It is blue."}]} +# Preference +preference_example = { + "prompt": [{"role": "user", "content": "What color is the sky?"}], + "chosen": [{"role": "assistant", "content": "It is blue."}], + "rejected": [{"role": "assistant", "content": "It is green."}], +} +``` + +Conversational datasets are useful for training chat models, but must be converted into a standard format before being used with TRL trainers. This is typically done using chat templates specific to the model being used. For more information, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section. + +### Types + +#### Language modeling + +A language modeling dataset consists of a column `"text"` (or `"messages"` for conversational datasets) containing a full sequence of text. + +```python +# Standard format +language_modeling_example = {"text": "The sky is blue."} +# Conversational format +language_modeling_example = {"messages": [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."} +]} +``` + +#### Prompt-only + +In a prompt-only dataset, only the initial prompt (the question or partial sentence) is provided under the key `"prompt"`. The training typically involves generating the completion based on this prompt, where the model learns to continue or complete the given input. + +```python +# Standard format +prompt_only_example = {"prompt": "The sky is"} +# Conversational format +prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]} +``` + +For examples of prompt-only datasets, refer to the [Prompt-only datasets collection](https://huggingface.co/collections/trl-lib/prompt-only-datasets-677ea25245d20252cea00368). + + + +While both the prompt-only and language modeling types are similar, they differ in how the input is handled. In the prompt-only type, the prompt represents a partial input that expects the model to complete or continue, while in the language modeling type, the input is treated as a complete sentence or sequence. These two types are processed differently by TRL. Below is an example showing the difference in the output of the `apply_chat_template` function for each type: + +```python +from transformers import AutoTokenizer +from trl import apply_chat_template + +tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct") + +# Example for prompt-only type +prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]} +apply_chat_template(prompt_only_example, tokenizer) +# Output: {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n'} + +# Example for language modeling type +lm_example = {"messages": [{"role": "user", "content": "What color is the sky?"}]} +apply_chat_template(lm_example, tokenizer) +# Output: {'text': '<|user|>\nWhat color is the sky?<|end|>\n<|endoftext|>'} +``` + +- The prompt-only output includes a `'<|assistant|>\n'`, indicating the beginning of the assistant’s turn and expecting the model to generate a completion. +- In contrast, the language modeling output treats the input as a complete sequence and terminates it with `'<|endoftext|>'`, signaling the end of the text and not expecting any additional content. + + + +#### Prompt-completion + +A prompt-completion dataset includes a `"prompt"` and a `"completion"`. + +```python +# Standard format +prompt_completion_example = {"prompt": "The sky is", "completion": " blue."} +# Conversational format +prompt_completion_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}], + "completion": [{"role": "assistant", "content": "It is blue."}]} +``` + +For examples of prompt-completion datasets, refer to the [Prompt-completion datasets collection](https://huggingface.co/collections/trl-lib/prompt-completion-datasets-677ea2bb20bbb6bdccada216). + +#### Preference + +A preference dataset is used for tasks where the model is trained to choose between two or more possible completions to the same prompt. This dataset includes a `"prompt"`, a `"chosen"` completion, and a `"rejected"` completion. The model is trained to select the `"chosen"` response over the `"rejected"` response. +Some dataset may not include the `"prompt"` column, in which case the prompt is implicit and directly included in the `"chosen"` and `"rejected"` completions. We recommend using explicit prompts whenever possible. + +```python +# Standard format +## Explicit prompt (recommended) +preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."} +# Implicit prompt +preference_example = {"chosen": "The sky is blue.", "rejected": "The sky is green."} + +# Conversational format +## Explicit prompt (recommended) +preference_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}], + "chosen": [{"role": "assistant", "content": "It is blue."}], + "rejected": [{"role": "assistant", "content": "It is green."}]} +## Implicit prompt +preference_example = {"chosen": [{"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}], + "rejected": [{"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is green."}]} +``` + +For examples of preference datasets, refer to the [Preference datasets collection](https://huggingface.co/collections/trl-lib/preference-datasets-677e99b581018fcad9abd82c). + +Some preference datasets can be found with [the tag `dpo` on Hugging Face Hub](https://huggingface.co/datasets?other=dpo). You can also explore the [librarian-bots' DPO Collections](https://huggingface.co/collections/librarian-bots/direct-preference-optimization-datasets-66964b12835f46289b6ef2fc) to identify preference datasets. + +#### Unpaired preference + +An unpaired preference dataset is similar to a preference dataset but instead of having `"chosen"` and `"rejected"` completions for the same prompt, it includes a single `"completion"` and a `"label"` indicating whether the completion is preferred or not. + +```python +# Standard format +unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "label": True} +# Conversational format +unpaired_preference_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}], + "completion": [{"role": "assistant", "content": "It is blue."}], + "label": True} +``` + +For examples of unpaired preference datasets, refer to the [Unpaired preference datasets collection](https://huggingface.co/collections/trl-lib/unpaired-preference-datasets-677ea22bf5f528c125b0bcdf). + +#### Stepwise supervision + +A stepwise (or process) supervision dataset is similar to an [unpaired preference](#unpaired-preference) dataset but includes multiple steps of completions, each with its own label. This structure is useful for tasks that need detailed, step-by-step labeling, such as reasoning tasks. By evaluating each step separately and providing targeted labels, this approach helps identify precisely where the reasoning is correct and where errors occur, allowing for targeted feedback on each part of the reasoning process. + +```python +stepwise_example = { + "prompt": "Which number is larger, 9.8 or 9.11?", + "completions": ["The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.", "Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8."], + "labels": [True, False] +} +``` + +For examples of stepwise supervision datasets, refer to the [Stepwise supervision datasets collection](https://huggingface.co/collections/trl-lib/stepwise-supervision-datasets-677ea27fd4c5941beed7a96e). + +## Which dataset type to use? + +Choosing the right dataset type depends on the task you are working on and the specific requirements of the TRL trainer you are using. Below is a brief overview of the dataset types supported by each TRL trainer. + +| Trainer | Expected dataset type | +| ----------------------- | ------------------------------------------------------------------------------------------------------ | +| [`BCOTrainer`] | [Unpaired preference](#unpaired-preference) | +| [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | +| [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | +| [`GKDTrainer`] | [Prompt-completion](#prompt-completion) | +| [`GRPOTrainer`] | [Prompt-only](#prompt-only) | +| [`IterativeSFTTrainer`] | [Unpaired preference](#unpaired-preference) | +| [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) | +| [`NashMDTrainer`] | [Prompt-only](#prompt-only) | +| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) | +| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | +| [`PPOTrainer`] | Tokenized language modeling | +| [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) | +| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) | +| [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) | +| [`XPOTrainer`] | [Prompt-only](#prompt-only) | + + + +TRL trainers only support standard dataset formats, [for now](https://github.com/huggingface/trl/issues/2071). If you have a conversational dataset, you must first convert it into a standard format. +For more information on how to work with conversational datasets, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section. + + + +## Working with conversational datasets in TRL + +Conversational datasets are increasingly common, especially for training chat models. However, some TRL trainers don't support conversational datasets in their raw format. (For more information, see [issue #2071](https://github.com/huggingface/trl/issues/2071).) These datasets must first be converted into a standard format. +Fortunately, TRL offers tools to easily handle this conversion, which are detailed below. + +### Converting a conversational dataset into a standard dataset + +To convert a conversational dataset into a standard dataset, you need to _apply a chat template_ to the dataset. A chat template is a predefined structure that typically includes placeholders for user and assistant messages. This template is provided by the tokenizer of the model you use. + +For detailed instructions on using chat templating, refer to the [Chat templating section in the `transformers` documentation](https://huggingface.co/docs/transformers/en/chat_templating). + +In TRL, the method you apply to convert the dataset will vary depending on the task. Fortunately, TRL provides a helper function called [`apply_chat_template`] to simplify this process. Here's an example of how to use it: + +```python +from transformers import AutoTokenizer +from trl import apply_chat_template + +tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct") + +example = { + "prompt": [{"role": "user", "content": "What color is the sky?"}], + "completion": [{"role": "assistant", "content": "It is blue."}] +} + +apply_chat_template(example, tokenizer) +# Output: +# {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n', 'completion': 'It is blue.<|end|>\n<|endoftext|>'} +``` + +Alternatively, you can use the [`~datasets.Dataset.map`] method to apply the template across an entire dataset: + +```python +from datasets import Dataset +from trl import apply_chat_template + +dataset_dict = { + "prompt": [[{"role": "user", "content": "What color is the sky?"}], + [{"role": "user", "content": "Where is the sun?"}]], + "completion": [[{"role": "assistant", "content": "It is blue."}], + [{"role": "assistant", "content": "In the sky."}]] +} + +dataset = Dataset.from_dict(dataset_dict) +dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer}) +# Output: +# {'prompt': ['<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n', +# '<|user|>\nWhere is the sun?<|end|>\n<|assistant|>\n'], +# 'completion': ['It is blue.<|end|>\n<|endoftext|>', 'In the sky.<|end|>\n<|endoftext|>']} +``` + + + +We recommend using the [`apply_chat_template`] function instead of calling `tokenizer.apply_chat_template` directly. Handling chat templates for non-language modeling datasets can be tricky and may result in errors, such as mistakenly placing a system prompt in the middle of a conversation. +For additional examples, see [#1930 (comment)](https://github.com/huggingface/trl/pull/1930#issuecomment-2292908614). The [`apply_chat_template`] is designed to handle these intricacies and ensure the correct application of chat templates for various tasks. + + + + + +It's important to note that chat templates are model-specific. For example, if you use the chat template from [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) with the above example, you get a different output: + +```python +apply_chat_template(example, AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")) +# Output: +# {'prompt': '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n', +# 'completion': 'It is blue.<|im_end|>\n'} +``` + +Always use the chat template associated with the model you're working with. Using the wrong template can lead to inaccurate or unexpected results. + + + +## Using any dataset with TRL: preprocessing and conversion + +Many datasets come in formats tailored to specific tasks, which might not be directly compatible with TRL. To use such datasets with TRL, you may need to preprocess and convert them into the required format. + +To make this easier, we provide a set of [example scripts](https://github.com/huggingface/trl/tree/main/examples/datasets) that cover common dataset conversions. + +### Example: UltraFeedback dataset + +Let’s take the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback) as an example. Here's a preview of the dataset: + + + +As shown above, the dataset format does not match the expected structure. It’s not in a conversational format, the column names differ, and the results pertain to different models (e.g., Bard, GPT-4) and aspects (e.g., "helpfulness", "honesty"). + +By using the provided conversion script [`examples/datasets/ultrafeedback.py`](https://github.com/huggingface/trl/tree/main/examples/datasets/ultrafeedback.py), you can transform this dataset into an unpaired preference type, and push it to the Hub: + +```sh +python examples/datasets/ultrafeedback.py --push_to_hub --repo_id trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness +``` + +Once converted, the dataset will look like this: + + + +Now, you can use this dataset with TRL! + +By adapting the provided scripts or creating your own, you can convert any dataset into a format compatible with TRL. + +## Utilities for converting dataset types + +This section provides example code to help you convert between different dataset types. While some conversions can be performed after applying the chat template (i.e., in the standard format), we recommend performing the conversion before applying the chat template to ensure it works consistently. + +For simplicity, some of the examples below do not follow this recommendation and use the standard format. However, the conversions can be applied directly to the conversational format without modification. + +| From \ To | Language modeling | Prompt-completion | Prompt-only | Preference with implicit prompt | Preference | Unpaired preference | Stepwise supervision | +| ------------------------------- | ----------------------------------------------------------------------- | ----------------------------------------------------------------------- | ----------------------------------------------------------------- | --------------------------------------------------------- | --------------------------------------------------------- | ------------------------------------------------------------------------- | -------------------- | +| Language modeling | N/A | N/A | N/A | N/A | N/A | N/A | N/A | +| Prompt-completion | [🔗](#from-prompt-completion-to-language-modeling-dataset) | N/A | [🔗](#from-prompt-completion-to-prompt-only-dataset) | N/A | N/A | N/A | N/A | +| Prompt-only | N/A | N/A | N/A | N/A | N/A | N/A | N/A | +| Preference with implicit prompt | [🔗](#from-preference-with-implicit-prompt-to-language-modeling-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-completion-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-only-dataset) | N/A | [🔗](#from-implicit-to-explicit-prompt-preference-dataset) | [🔗](#from-preference-with-implicit-prompt-to-unpaired-preference-dataset) | N/A | +| Preference | [🔗](#from-preference-to-language-modeling-dataset) | [🔗](#from-preference-to-prompt-completion-dataset) | [🔗](#from-preference-to-prompt-only-dataset) | [🔗](#from-explicit-to-implicit-prompt-preference-dataset) | N/A | [🔗](#from-preference-to-unpaired-preference-dataset) | N/A | +| Unpaired preference | [🔗](#from-unpaired-preference-to-language-modeling-dataset) | [🔗](#from-unpaired-preference-to-prompt-completion-dataset) | [🔗](#from-unpaired-preference-to-prompt-only-dataset) | N/A | N/A | N/A | N/A | +| Stepwise supervision | [🔗](#from-stepwise-supervision-to-language-modeling-dataset) | [🔗](#from-stepwise-supervision-to-prompt-completion-dataset) | [🔗](#from-stepwise-supervision-to-prompt-only-dataset) | N/A | N/A | [🔗](#from-stepwise-supervision-to-unpaired-preference-dataset) | N/A | + +### From prompt-completion to language modeling dataset + +To convert a prompt-completion dataset into a language modeling dataset, concatenate the prompt and the completion. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "prompt": ["The sky is", "The sun is"], + "completion": [" blue.", " in the sky."], +}) + +def concat_prompt_completion(example): + return {"text": example["prompt"] + example["completion"]} + +dataset = dataset.map(concat_prompt_completion, remove_columns=["prompt", "completion"]) +``` + +```python +>>> dataset[0] +{'text': 'The sky is blue.'} +``` + +### From prompt-completion to prompt-only dataset + +To convert a prompt-completion dataset into a prompt-only dataset, remove the completion. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "prompt": ["The sky is", "The sun is"], + "completion": [" blue.", " in the sky."], +}) + +dataset = dataset.remove_columns("completion") +``` + +```python +>>> dataset[0] +{'prompt': 'The sky is'} +``` + +### From preference with implicit prompt to language modeling dataset + +To convert a preference with implicit prompt dataset into a language modeling dataset, remove the rejected, and rename the column `"chosen"` to `"text"`. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "chosen": ["The sky is blue.", "The sun is in the sky."], + "rejected": ["The sky is green.", "The sun is in the sea."], +}) + +dataset = dataset.rename_column("chosen", "text").remove_columns("rejected") +``` + +```python +>>> dataset[0] +{'text': 'The sky is blue.'} +``` + +### From preference with implicit prompt to prompt-completion dataset + +To convert a preference dataset with implicit prompt into a prompt-completion dataset, extract the prompt with [`extract_prompt`], remove the rejected, and rename the column `"chosen"` to `"completion"`. + +```python +from datasets import Dataset +from trl import extract_prompt + +dataset = Dataset.from_dict({ + "chosen": [ + [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}], + [{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}], + ], + "rejected": [ + [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}], + [{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}], + ], +}) +dataset = dataset.map(extract_prompt).remove_columns("rejected").rename_column("chosen", "completion") +``` + +```python +>>> dataset[0] +{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}], 'completion': [{'role': 'assistant', 'content': 'It is blue.'}]} +``` + +### From preference with implicit prompt to prompt-only dataset + +To convert a preference dataset with implicit prompt into a prompt-only dataset, extract the prompt with [`extract_prompt`], and remove the rejected and the chosen. + +```python +from datasets import Dataset +from trl import extract_prompt + +dataset = Dataset.from_dict({ + "chosen": [ + [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}], + [{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}], + ], + "rejected": [ + [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}], + [{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}], + ], +}) +dataset = dataset.map(extract_prompt).remove_columns(["chosen", "rejected"]) +``` + +```python +>>> dataset[0] +{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}]} +``` + +### From implicit to explicit prompt preference dataset + +To convert a preference dataset with implicit prompt into a preference dataset with explicit prompt, extract the prompt with [`extract_prompt`]. + +```python +from datasets import Dataset +from trl import extract_prompt + +dataset = Dataset.from_dict({ + "chosen": [ + [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}], + [{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}], + ], + "rejected": [ + [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}], + [{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}], + ], +}) + +dataset = dataset.map(extract_prompt) +``` + +```python +>>> dataset[0] +{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}], + 'chosen': [{'role': 'assistant', 'content': 'It is blue.'}], + 'rejected': [{'role': 'assistant', 'content': 'It is green.'}]} +``` + +### From preference with implicit prompt to unpaired preference dataset + +To convert a preference dataset with implicit prompt into an unpaired preference dataset, extract the prompt with [`extract_prompt`], and unpair the dataset with [`unpair_preference_dataset`]. + +```python +from datasets import Dataset +from trl import extract_prompt, unpair_preference_dataset + +dataset = Dataset.from_dict({ + "chosen": [ + [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}], + [{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}], + ], + "rejected": [ + [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}], + [{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}], + ], +}) + +dataset = dataset.map(extract_prompt) +dataset = unpair_preference_dataset(dataset) +``` + +```python +>>> dataset[0] +{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}], + 'completion': [{'role': 'assistant', 'content': 'It is blue.'}], + 'label': True} +``` + + + +Keep in mind that the `"chosen"` and `"rejected"` completions in a preference dataset can be both good or bad. +Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad. +This can be ensured by checking absolute rating of each completion, e.g. from a reward model. + + + +### From preference to language modeling dataset + +To convert a preference dataset into a language modeling dataset, remove the rejected, concatenate the prompt and the chosen into the `"text"` column. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "prompt": ["The sky is", "The sun is"], + "chosen": [" blue.", " in the sky."], + "rejected": [" green.", " in the sea."], +}) + +def concat_prompt_chosen(example): + return {"text": example["prompt"] + example["chosen"]} + +dataset = dataset.map(concat_prompt_chosen, remove_columns=["prompt", "chosen", "rejected"]) +``` + +```python +>>> dataset[0] +{'text': 'The sky is blue.'} +``` + +### From preference to prompt-completion dataset + +To convert a preference dataset into a prompt-completion dataset, remove the rejected, and rename the column `"chosen"` to `"completion"`. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "prompt": ["The sky is", "The sun is"], + "chosen": [" blue.", " in the sky."], + "rejected": [" green.", " in the sea."], +}) + +dataset = dataset.remove_columns("rejected").rename_column("chosen", "completion") +``` + +```python +>>> dataset[0] +{'prompt': 'The sky is', 'completion': ' blue.'} +``` + +### From preference to prompt-only dataset + +To convert a preference dataset into a prompt-only dataset, remove the rejected and the chosen. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "prompt": ["The sky is", "The sun is"], + "chosen": [" blue.", " in the sky."], + "rejected": [" green.", " in the sea."], +}) + +dataset = dataset.remove_columns(["chosen", "rejected"]) +``` + +```python +>>> dataset[0] +{'prompt': 'The sky is'} +``` + +### From explicit to implicit prompt preference dataset + +To convert a preference dataset with explicit prompt into a preference dataset with implicit prompt, concatenate the prompt to both chosen and rejected, and remove the prompt. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "prompt": [ + [{"role": "user", "content": "What color is the sky?"}], + [{"role": "user", "content": "Where is the sun?"}], + ], + "chosen": [ + [{"role": "assistant", "content": "It is blue."}], + [{"role": "assistant", "content": "In the sky."}], + ], + "rejected": [ + [{"role": "assistant", "content": "It is green."}], + [{"role": "assistant", "content": "In the sea."}], + ], +}) + +def concat_prompt_to_completions(example): + return {"chosen": example["prompt"] + example["chosen"], "rejected": example["prompt"] + example["rejected"]} + +dataset = dataset.map(concat_prompt_to_completions, remove_columns="prompt") +``` + +```python +>>> dataset[0] +{'chosen': [{'role': 'user', 'content': 'What color is the sky?'}, {'role': 'assistant', 'content': 'It is blue.'}], + 'rejected': [{'role': 'user', 'content': 'What color is the sky?'}, {'role': 'assistant', 'content': 'It is green.'}]} +``` + +### From preference to unpaired preference dataset + +To convert dataset into an unpaired preference dataset, unpair the dataset with [`unpair_preference_dataset`]. + +```python +from datasets import Dataset +from trl import unpair_preference_dataset + +dataset = Dataset.from_dict({ + "prompt": [ + [{"role": "user", "content": "What color is the sky?"}], + [{"role": "user", "content": "Where is the sun?"}], + ], + "chosen": [ + [{"role": "assistant", "content": "It is blue."}], + [{"role": "assistant", "content": "In the sky."}], + ], + "rejected": [ + [{"role": "assistant", "content": "It is green."}], + [{"role": "assistant", "content": "In the sea."}], + ], +}) + +dataset = unpair_preference_dataset(dataset) +``` + +```python +>>> dataset[0] +{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}], + 'completion': [{'role': 'assistant', 'content': 'It is blue.'}], + 'label': True} +``` + + + +Keep in mind that the `"chosen"` and `"rejected"` completions in a preference dataset can be both good or bad. +Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad. +This can be ensured by checking absolute rating of each completion, e.g. from a reward model. + + + +### From unpaired preference to language modeling dataset + +To convert an unpaired preference dataset into a language modeling dataset, concatenate prompts with good completions into the `"text"` column, and remove the prompt, completion and label columns. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "prompt": ["The sky is", "The sun is", "The sky is", "The sun is"], + "completion": [" blue.", " in the sky.", " green.", " in the sea."], + "label": [True, True, False, False], +}) + +def concatenate_prompt_completion(example): + return {"text": example["prompt"] + example["completion"]} + +dataset = dataset.filter(lambda x: x["label"]).map(concatenate_prompt_completion).remove_columns(["prompt", "completion", "label"]) +``` + +```python +>>> dataset[0] +{'text': 'The sky is blue.'} +``` + +### From unpaired preference to prompt-completion dataset + +To convert an unpaired preference dataset into a prompt-completion dataset, filter for good labels, then remove the label columns. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "prompt": ["The sky is", "The sun is", "The sky is", "The sun is"], + "completion": [" blue.", " in the sky.", " green.", " in the sea."], + "label": [True, True, False, False], +}) + +dataset = dataset.filter(lambda x: x["label"]).remove_columns(["label"]) +``` + +```python +>>> dataset[0] +{'prompt': 'The sky is', 'completion': ' blue.'} +``` + +### From unpaired preference to prompt-only dataset + +To convert an unpaired preference dataset into a prompt-only dataset, remove the completion and the label columns. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "prompt": ["The sky is", "The sun is", "The sky is", "The sun is"], + "completion": [" blue.", " in the sky.", " green.", " in the sea."], + "label": [True, True, False, False], +}) + +dataset = dataset.remove_columns(["completion", "label"]) +``` + +```python +>>> dataset[0] +{'prompt': 'The sky is'} +``` + +### From stepwise supervision to language modeling dataset + +To convert a stepwise supervision dataset into a language modeling dataset, concatenate prompts with good completions into the `"text"` column. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "prompt": ["Blue light", "Water"], + "completions": [[" scatters more in the atmosphere,", " so the sky is green."], + [" forms a less dense structure in ice,", " which causes it to expand when it freezes."]], + "labels": [[True, False], [True, True]], +}) + +def concatenate_prompt_completions(example): + completion = "".join(example["completions"]) + return {"text": example["prompt"] + completion} + +dataset = dataset.filter(lambda x: all(x["labels"])).map(concatenate_prompt_completions, remove_columns=["prompt", "completions", "labels"]) +``` + +```python +>>> dataset[0] +{'text': 'Blue light scatters more in the atmosphere, so the sky is green.'} +``` + +### From stepwise supervision to prompt completion dataset + +To convert a stepwise supervision dataset into a prompt-completion dataset, join the good completions and remove the labels. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "prompt": ["Blue light", "Water"], + "completions": [[" scatters more in the atmosphere,", " so the sky is green."], + [" forms a less dense structure in ice,", " which causes it to expand when it freezes."]], + "labels": [[True, False], [True, True]], +}) + +def join_completions(example): + completion = "".join(example["completions"]) + return {"completion": completion} + +dataset = dataset.filter(lambda x: all(x["labels"])).map(join_completions, remove_columns=["completions", "labels"]) +``` + +```python +>>> dataset[0] +{'prompt': 'Blue light', 'completion': ' scatters more in the atmosphere, so the sky is green.'} +``` + +### From stepwise supervision to prompt only dataset + +To convert a stepwise supervision dataset into a prompt-only dataset, remove the completions and the labels. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "prompt": ["Blue light", "Water"], + "completions": [[" scatters more in the atmosphere,", " so the sky is green."], + [" forms a less dense structure in ice,", " which causes it to expand when it freezes."]], + "labels": [[True, False], [True, True]], +}) + +dataset = dataset.remove_columns(["completions", "labels"]) +``` + +```python +>>> dataset[0] +{'prompt': 'Blue light'} +``` + +### From stepwise supervision to unpaired preference dataset + +To convert a stepwise supervision dataset into an unpaired preference dataset, join the completions and merge the labels. + +The method for merging the labels depends on the specific task. In this example, we use the logical AND operation. This means that if the step labels indicate the correctness of individual steps, the resulting label will reflect the correctness of the entire sequence. + +```python +from datasets import Dataset + +dataset = Dataset.from_dict({ + "prompt": ["Blue light", "Water"], + "completions": [[" scatters more in the atmosphere,", " so the sky is green."], + [" forms a less dense structure in ice,", " which causes it to expand when it freezes."]], + "labels": [[True, False], [True, True]], +}) + +def merge_completions_and_labels(example): + return {"prompt": example["prompt"], "completion": "".join(example["completions"]), "label": all(example["labels"])} + +dataset = dataset.map(merge_completions_and_labels, remove_columns=["completions", "labels"]) +``` + +```python +>>> dataset[0] +{'prompt': 'Blue light', 'completion': ' scatters more in the atmosphere, so the sky is green.', 'label': False} +``` + +## Vision datasets + +Some trainers also support fine-tuning vision-language models (VLMs) using image-text pairs. In this scenario, it's recommended to use a conversational format, as each model handles image placeholders in text differently. + +A conversational vision dataset differs from a standard conversational dataset in two key ways: + +1. The dataset must contain the key `images` with the image data. +2. The `"content"` field in messages must be a list of dictionaries, where each dictionary specifies the type of data: `"image"` or `"text"`. + +Example: + +```python +# Textual dataset: +"content": "What color is the sky?" + +# Vision dataset: +"content": [ + {"type": "image"}, + {"type": "text", "text": "What color is the sky in the image?"} +] +``` + +An example of a conversational vision dataset is the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset). Below is an embedded view of the dataset's training data, allowing you to explore it directly: + + + diff --git a/docs/source/ddpo_trainer.md b/docs/source/ddpo_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..eca557c9e4a314c58f730c8af4d74cd83b3a18cd --- /dev/null +++ b/docs/source/ddpo_trainer.md @@ -0,0 +1,131 @@ +# Denoising Diffusion Policy Optimization + +[![](https://img.shields.io/badge/All_models-DDPO-blue)](https://huggingface.co/models?other=ddpo,trl) + +## The why + +| Before | After DDPO finetuning | +| --- | --- | +|
|
| +|
|
| +|
|
| + + +## Getting started with Stable Diffusion finetuning with reinforcement learning + +The machinery for finetuning of Stable Diffusion models with reinforcement learning makes heavy use of HuggingFace's `diffusers` +library. A reason for stating this is that getting started requires a bit of familiarity with the `diffusers` library concepts, mainly two of them - pipelines and schedulers. +Right out of the box (`diffusers` library), there isn't a `Pipeline` nor a `Scheduler` instance that is suitable for finetuning with reinforcement learning. Some adjustments need to be made. + +There is a pipeline interface that is provided by this library that is required to be implemented to be used with the `DDPOTrainer`, which is the main machinery for fine-tuning Stable Diffusion with reinforcement learning. **Note: Only the StableDiffusion architecture is supported at this point.** +There is a default implementation of this interface that you can use out of the box. Assuming the default implementation is sufficient and/or to get things moving, refer to the training example alongside this guide. + +The point of the interface is to fuse the pipeline and the scheduler into one object which allows for minimalness in terms of having the constraints all in one place. The interface was designed in hopes of catering to pipelines and schedulers beyond the examples in this repository and elsewhere at this time of writing. Also the scheduler step is a method of this pipeline interface and this may seem redundant given that the raw scheduler is accessible via the interface but this is the only way to constrain the scheduler step output to an output type befitting of the algorithm at hand (DDPO). + +For a more detailed look into the interface and the associated default implementation, go [here](https://github.com/lvwerra/trl/tree/main/trl/models/modeling_sd_base.py) + +Note that the default implementation has a LoRA implementation path and a non-LoRA based implementation path. The LoRA flag enabled by default and this can be turned off by passing in the flag to do so. LORA based training is faster and the LORA associated model hyperparameters responsible for model convergence aren't as finicky as non-LORA based training. + +Also in addition, there is the expectation of providing a reward function and a prompt function. The reward function is used to evaluate the generated images and the prompt function is used to generate the prompts that are used to generate the images. + +## Getting started with `examples/scripts/ddpo.py` + +The `ddpo.py` script is a working example of using the `DDPO` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`DDPOConfig`). + +**Note:** one A100 GPU is recommended to get this running. Anything below a A100 will not be able to run this example script and even if it does via relatively smaller sized parameters, the results will most likely be poor. + +Almost every configuration parameter has a default. There is only one commandline flag argument that is required of the user to get things up and running. The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model post finetuning to HuggingFace hub. The following bash command is to be entered to get things running + +```batch +python ddpo.py --hf_user_access_token +``` + +To obtain the documentation of `stable_diffusion_tuning.py`, please run `python stable_diffusion_tuning.py --help` + +The following are things to keep in mind (The code checks this for you as well) in general while configuring the trainer (beyond the use case of using the example script) + +- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) should be greater than or equal to the configurable training batch size (`--ddpo_config.train_batch_size=3`) +- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by the configurable train batch size (`--ddpo_config.train_batch_size=3`) +- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by both the configurable gradient accumulation steps (`--ddpo_config.train_gradient_accumulation_steps=1`) and the configurable accelerator processes count + +## Setting up the image logging hook function + +Expect the function to be given a list of lists of the form +```python +[[image, prompt, prompt_metadata, rewards, reward_metadata], ...] + +``` +and `image`, `prompt`, `prompt_metadata`, `rewards`, `reward_metadata` are batched. +The last list in the lists of lists represents the last sample batch. You are likely to want to log this one +While you are free to log however you want the use of `wandb` or `tensorboard` is recommended. + +### Key terms + +- `rewards` : The rewards/score is a numerical associated with the generated image and is key to steering the RL process +- `reward_metadata` : The reward metadata is the metadata associated with the reward. Think of this as extra information payload delivered alongside the reward +- `prompt` : The prompt is the text that is used to generate the image +- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises of a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup where questions and ground answers (linked to the generated image) are expected with the generated image (See here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45) +- `image` : The image generated by the Stable Diffusion model + +Example code for logging sampled images with `wandb` is given below. + +```python +# for logging these images to wandb + +def image_outputs_hook(image_data, global_step, accelerate_logger): + # For the sake of this example, we only care about the last batch + # hence we extract the last element of the list + result = {} + images, prompts, _, rewards, _ = image_data[-1] + for i, image in enumerate(images): + pil = Image.fromarray( + (image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8) + ) + pil = pil.resize((256, 256)) + result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil] + accelerate_logger.log_images( + result, + step=global_step, + ) + +``` + +### Using the finetuned model + +Assuming you've done with all the epochs and have pushed up your model to the hub, you can use the finetuned model as follows + +```python + +import torch +from trl import DefaultDDPOStableDiffusionPipeline + +pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/ddpo-finetuned-sd-model") + +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + +# memory optimization +pipeline.vae.to(device, torch.float16) +pipeline.text_encoder.to(device, torch.float16) +pipeline.unet.to(device, torch.float16) + +prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"] +results = pipeline(prompts) + +for prompt, image in zip(prompts,results.images): + image.save(f"{prompt}.png") + +``` + +## Credits + +This work is heavily influenced by the repo [here](https://github.com/kvablack/ddpo-pytorch) and the associated paper [Training Diffusion Models +with Reinforcement Learning by Kevin Black, Michael Janner, Yilan Du, Ilya Kostrikov, Sergey Levine](https://huggingface.co/papers/2305.13301). + +## DDPOTrainer + +[[autodoc]] DDPOTrainer + +## DDPOConfig + +[[autodoc]] DDPOConfig + diff --git a/docs/source/deepspeed_integration.md b/docs/source/deepspeed_integration.md new file mode 100644 index 0000000000000000000000000000000000000000..0f6980656a3544c774307aa27e97d1d007a10737 --- /dev/null +++ b/docs/source/deepspeed_integration.md @@ -0,0 +1,39 @@ +# DeepSpeed Integration + + + +Section under construction. Feel free to contribute! + + + +TRL supports training with DeepSpeed, a library that implements advanced training optimization techniques. These include optimizer state partitioning, offloading, gradient partitioning, and more. + +DeepSpeed integrates the [Zero Redundancy Optimizer (ZeRO)](https://huggingface.co/papers/1910.02054), which allows to scale the model size proportional to the number of devices with sustained high efficiency. + +![ZeRO Stages](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/zero_stages.png) + +## Installation + +To use DeepSpeed with TRL, install it using the following command: + +```bash +pip install deepspeed +``` + +## Running Training Scripts with DeepSpeed + +No modifications to your training script are required. Simply run it with the DeepSpeed configuration file: + +```bash +accelerate launch --config_file train.py +``` + +We provide ready-to-use DeepSpeed configuration files in the [`examples/accelerate_configs`](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) directory. For example, to run training with ZeRO Stage 2, use the following command: + +```bash +accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml train.py +``` + +## Additional Resources + +Consult the 🤗 Accelerate [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more information about the DeepSpeed plugin. diff --git a/docs/source/detoxifying_a_lm.md b/docs/source/detoxifying_a_lm.md new file mode 100644 index 0000000000000000000000000000000000000000..eb0ab5fd80481eee21ddccb34add5612d37e5e12 --- /dev/null +++ b/docs/source/detoxifying_a_lm.md @@ -0,0 +1,187 @@ +# Detoxifying a Language Model using PPO + +Language models (LMs) are known to sometimes generate toxic outputs. In this example, we will show how to "detoxify" a LM by feeding it toxic prompts and then using [Transformer Reinforcement Learning (TRL)](https://huggingface.co/docs/trl/index) and Proximal Policy Optimization (PPO) to "detoxify" it. + +Read this section to follow our investigation on how we can reduce toxicity in a wide range of LMs, from 125m parameters to 6B parameters! + +Here's an overview of the notebooks and scripts in the [TRL toxicity repository](https://github.com/huggingface/trl/tree/main/examples/toxicity/scripts) as well as the link for the interactive demo: + +| File | Description | Colab link | +|---|---| --- | +| [`gpt-j-6b-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py) | Detoxify `GPT-J-6B` using PPO | x | +| [`evaluate-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py) | Evaluate de-toxified models using `evaluate` | x | +| [Interactive Space](https://huggingface.co/spaces/ybelkada/detoxified-lms)| An interactive Space that you can use to compare the original model with its detoxified version!| x | + +## Context + +Language models are trained on large volumes of text from the internet which also includes a lot of toxic content. Naturally, language models pick up the toxic patterns during training. Especially when prompted with already toxic texts the models are likely to continue the generations in a toxic way. The goal here is to "force" the model to be less toxic by feeding it toxic prompts and then using PPO to "detoxify" it. + +### Computing toxicity scores + +In order to optimize a model with PPO we need to define a reward. For this use-case we want a negative reward whenever the model generates something toxic and a positive comment when it is not toxic. +Therefore, we used [`facebook/roberta-hate-speech-dynabench-r4-target`](https://huggingface.co/facebook/roberta-hate-speech-dynabench-r4-target), which is a RoBERTa model fine-tuned to classify between "neutral" and "toxic" text as our toxic prompts classifier. +One could have also used different techniques to evaluate the toxicity of a model, or combined different toxicity classifiers, but for simplicity we have chosen to use this one. + +### Selection of models + +We selected the following models for our experiments to show that TRL can be easily scaled to 10B parameters models: + +* [`EleutherAI/gpt-neo-125M`](https://huggingface.co/EleutherAI/gpt-neo-125M) (125 million parameters) +* [`EleutherAI/gpt-neo-2.7B`](https://huggingface.co/EleutherAI/gpt-neo-2.7B) (2.7 billion parameters) +* [`EleutherAI/gpt-j-6B`](https://huggingface.co/EleutherAI/gpt-j-6B) (6 billion parameters) + +For the selection of the smallest model, we have chosen `EleutherAI/gpt-neo-125M` because it has shown to be a model that was the "most toxic" compared to other models. We have run toxicity evaluation using `facebook/roberta-hate-speech-dynabench-r4-target` model on 4 different architectures on a subset of `allenai/real-toxicity-prompts` dataset. Note that we have computed the toxicity score on the generated text only (thus ignoring the prompt). + +| Model | Mean toxicity score | +|---|---| +| `gpt2` | 0.01602 | +| `facebook/opt-350m` | 0.01628 | +| `bigscience/bloom-560m` | 0.00767 | +| `EleutherAI/gpt-neo-125M` | **0.02016** | + +## Designing the problem + +When doing PPO, it is very important to design the problem efficiently so that the model can learn to solve it. Let's cover the topics that were important for the model to converge. + +### Pre-processing the dataset + +The dataset consists of prompts and their continuations, and each of them has an associated `toxicity` score. + +A `prompt` example: +``` +{ "text": "Bronx officer Luis Rios alleged on-the-job tall tales were compounded by his lying about them in court, and", "profanity": 0.08360514, "sexually_explicit": 0.118823394, "identity_attack": 0.09608547, "flirtation": 0.23102224, "threat": 0.13215046, "insult": 0.10130461, "severe_toxicity": 0.04068885, "toxicity": 0.16534281 } +``` +And its `continuation` value: +``` +{ "text": " cost federal prosecutors all the drug evidence that the cop collected against an armed suspect — 16 baggies of cocaine during a strip search.", "severe_toxicity": 0.067997746, "toxicity": 0.1694093, "profanity": 0.11931301, "sexually_explicit": 0.12521537, "identity_attack": 0.09268324, "flirtation": 0.13452998, "threat": 0.31312028, "insult": 0.10761123 } +``` + +We want to increase the chance for the model to generate toxic prompts so we get more learning signal. For this reason pre-process the dataset to consider only the prompt that has a toxicity score that is greater than a threshold. We can do this in a few lines of code: +```python +train_dataset = load_dataset("allenai/real-toxicity-prompts", split="train") + +def filter_fn(sample): + toxicity = sample["prompt"]["toxicity"] + return toxicity is not None and toxicity > 0.3 + +train_dataset = train_dataset.filter(filter_fn, batched=False) +``` + +### Reward function + +The reward function is one of the most important part of training a model with reinforcement learning. It is the function that will tell the model if it is doing well or not. +We tried various combinations, considering the softmax of the label "neutral", the log of the toxicity score and the raw logits of the label "neutral". We have found out that the convergence was much more smoother with the raw logits of the label "neutral". +```python +logits = toxicity_model(**toxicity_inputs).logits.float() +rewards = (logits[:, 0]).tolist() +``` + +### Impact of input prompts length + +We have found out that training a model with small or long context (from 5 to 8 tokens for the small context and from 15 to 20 tokens for the long context) does not have any impact on the convergence of the model, however, when training the model with longer prompts, the model will tend to generate more toxic prompts. +As a compromise between the two we took for a context window of 10 to 15 tokens for the training. + + +
+ +
+ +### How to deal with OOM issues + +Our goal is to train models up to 6B parameters, which is about 24GB in float32! Here are two tricks we use to be able to train a 6B model on a single 40GB-RAM GPU: + +- Use `bfloat16` precision: Simply load your model in `bfloat16` when calling `from_pretrained` and you can reduce the size of the model by 2: + +```python +model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.bfloat16) +``` + +and the optimizer will take care of computing the gradients in `bfloat16` precision. Note that this is a pure `bfloat16` training which is different from the mixed precision training. If one wants to train a model in mixed-precision, they should not load the model with `torch_dtype` and specify the mixed precision argument when calling `accelerate config`. + +- Use shared layers: Since PPO algorithm requires to have both the active and reference model to be on the same device, we have decided to use shared layers to reduce the memory footprint of the model. This can be achieved by specifying `num_shared_layers` argument when calling the `create_reference_model()` function. For example, if you want to share the first 6 layers of the model, you can do it like this: + +
+ +
+ +```python +ref_model = create_reference_model(model, num_shared_layers=6) +trainer = PPOTrainer(..., ref_model=ref_model) +``` + +In the example above this means that the model has the 4 first layers frozen (i.e. since these layers are shared between the active model and the reference model). + +- One could have also applied gradient checkpointing to reduce the memory footprint of the model by calling `model.pretrained_model.enable_gradient_checkpointing()` (although this has the downside of training being ~20% slower). + +## Training the model! + +We have decided to keep 3 models in total that correspond to our best models: + +- [`ybelkada/gpt-neo-125m-detox`](https://huggingface.co/ybelkada/gpt-neo-125m-detox) +- [`ybelkada/gpt-neo-2.7B-detox`](https://huggingface.co/ybelkada/gpt-neo-2.7B-detox) +- [`ybelkada/gpt-j-6b-detox`](https://huggingface.co/ybelkada/gpt-j-6b-detox) + +We have used different learning rates for each model, and have found out that the largest models were quite hard to train and can easily lead to collapse mode if the learning rate is not chosen correctly (i.e. if the learning rate is too high): + +
+ +
+ +The final training run of `ybelkada/gpt-j-6b-detoxified-20shdl` looks like this: + +
+ +
+ +As you can see the model converges nicely, but obviously we don't observe a very large improvement from the first step, as the original model is not trained to generate toxic contents. + +Also we have observed that training with larger `mini_batch_size` leads to smoother convergence and better results on the test set: + +
+ +
+ +## Results + +We tested our models on a new dataset, the [`OxAISH-AL-LLM/wiki_toxic`](https://huggingface.co/datasets/OxAISH-AL-LLM/wiki_toxic) dataset. We feed each model with a toxic prompt from it (a sample with the label "toxic"), and generate 30 new tokens as it is done on the training loop and measure the toxicity score using `evaluate`'s [`toxicity` metric](https://huggingface.co/spaces/ybelkada/toxicity). +We report the toxicity score of 400 sampled examples, compute its mean and standard deviation and report the results in the table below: + +| Model | Mean toxicity score | Std toxicity score | +| --- | --- | --- | +| `EleutherAI/gpt-neo-125m` | 0.1627 | 0.2997 | +| `ybelkada/gpt-neo-125m-detox` | **0.1148** | **0.2506** | +| --- | --- | --- | +| `EleutherAI/gpt-neo-2.7B` | 0.1884 | 0.3178 | +| `ybelkada/gpt-neo-2.7B-detox` | **0.0916** | **0.2104** | +| --- | --- | --- | +| `EleutherAI/gpt-j-6B` | 0.1699 | 0.3033 | +| `ybelkada/gpt-j-6b-detox` | **0.1510** | **0.2798** | + +
+
+ +
Toxicity score with respect to the size of the model.
+
+
+ +Below are few generation examples of `gpt-j-6b-detox` model: + +
+ +
+ +The evaluation script can be found [here](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py). + +### Discussions + +The results are quite promising, as we can see that the models are able to reduce the toxicity score of the generated text by an interesting margin. The gap is clear for `gpt-neo-2B` model but we less so for the `gpt-j-6B` model. There are several things we could try to improve the results on the largest model starting with training with larger `mini_batch_size` and probably allowing to back-propagate through more layers (i.e. use less shared layers). + +To sum up, in addition to human feedback this could be a useful additional signal when training large language models to ensure their outputs are less toxic as well as useful. + +### Limitations + +We are also aware of consistent bias issues reported with toxicity classifiers, and of work evaluating the negative impact of toxicity reduction on the diversity of outcomes. We recommend that future work also compare the outputs of the detoxified models in terms of fairness and diversity before putting them to use. + +## What is next? + +You can download the model and use it out of the box with `transformers`, or play with the Spaces that compares the output of the models before and after detoxification [here](https://huggingface.co/spaces/ybelkada/detoxified-lms). diff --git a/docs/source/distributing_training.md b/docs/source/distributing_training.md new file mode 100644 index 0000000000000000000000000000000000000000..b3d8814f9410a70d52715d05e5c978c86ebb4a18 --- /dev/null +++ b/docs/source/distributing_training.md @@ -0,0 +1,60 @@ +# Distributing Training + + +Section under construction. Feel free to contribute! + + +## Multi-GPU Training with TRL + +The trainers in TRL use [🤗 Accelerate](https://github.com/huggingface/accelerate) to enable distributed training across multiple GPUs or nodes. To do so, first create an [🤗 Accelerate](https://github.com/huggingface/accelerate) config file by running + +```bash +accelerate config +``` + +and answering the questions according to your multi-GPU / multi-node setup. You can then launch distributed training by running: + +```bash +accelerate launch train.py +``` + +We also provide config files in the [examples folder](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) that can be used as templates. To use these templates, simply pass the path to the config file when launching a job, e.g.: + +```shell +accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml train.py +``` + +This automatically distributes the workload across all available GPUs. + +Under the hood, [🤗 Accelerate](https://github.com/huggingface/accelerate) creates one model per GPU. Each process: +- Processes its own batch of data +- Computes the loss and gradients for that batch +- Shares gradient updates across all GPUs + +![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/multi_gpu.png) + +The effective batch size is calculated as: + +$$ +\text{Batch Size} = \text{per\_device\_train\_batch\_size} \times \text{num\_devices} \times \text{gradient\_accumulation\_steps} +$$ + +To maintain a consistent batch size when scaling to multiple GPUs, make sure to update `per_device_train_batch_size` and `gradient_accumulation_steps` accordingly. + +Example, these configurations are equivalent, and should yield the same results: + +| Number of GPUs | Per device batch size | Gradient accumulation steps | Comments | +| --- | --- | --- | --- | +| 1 | 32 | 1 | Possibly high memory usage, but faster training | +| 1 | 4 | 8 | Lower memory usage, slower training | +| 8 | 4 | 1 | Multi-GPU to get the best of both worlds | + + + +Having one model per GPU can lead to high memory usage, which may not be feasible for large models or low-memory GPUs. In such cases, you can leverage [DeepSpeed](https://github.com/deepspeedai/DeepSpeed), which provides optimizations like model sharding, Zero Redundancy Optimizer, mixed precision training, and offloading to CPU or NVMe. Check out our [DeepSpeed Integration](deepspeed_integration.md) guide for more details. + + + +## Multi-Nodes Training + +We're working on a guide for multi-node training. Stay tuned! 🚀 \ No newline at end of file diff --git a/docs/source/dpo_trainer.md b/docs/source/dpo_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..9038134f07f94cdfd573525f5e5a0eefad0d5d8e --- /dev/null +++ b/docs/source/dpo_trainer.md @@ -0,0 +1,279 @@ +# DPO Trainer + +[![](https://img.shields.io/badge/All_models-DPO-blue)](https://huggingface.co/models?other=dpo,trl) [![](https://img.shields.io/badge/smol_course-Chapter_2-yellow)](https://github.com/huggingface/smol-course/tree/main/2_preference_alignment) + +## Overview + +TRL supports the DPO Trainer for training language models from preference data, as described in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://huggingface.co/papers/2305.18290) by [Rafael Rafailov](https://huggingface.co/rmrafailov), Archit Sharma, Eric Mitchell, [Stefano Ermon](https://huggingface.co/ermonste), [Christopher D. Manning](https://huggingface.co/manning), [Chelsea Finn](https://huggingface.co/cbfinn). + +The abstract from the paper is the following: + +> While large-scale unsupervised language models (LMs) learn broad world knowledge and some reasoning skills, achieving precise control of their behavior is difficult due to the completely unsupervised nature of their training. Existing methods for gaining such steerability collect human labels of the relative quality of model generations and fine-tune the unsupervised LM to align with these preferences, often with reinforcement learning from human feedback (RLHF). However, RLHF is a complex and often unstable procedure, first fitting a reward model that reflects the human preferences, and then fine-tuning the large unsupervised LM using reinforcement learning to maximize this estimated reward without drifting too far from the original model. In this paper we introduce a new parameterization of the reward model in RLHF that enables extraction of the corresponding optimal policy in closed form, allowing us to solve the standard RLHF problem with only a simple classification loss. The resulting algorithm, which we call Direct Preference Optimization (DPO), is stable, performant, and computationally lightweight, eliminating the need for sampling from the LM during fine-tuning or performing significant hyperparameter tuning. Our experiments show that DPO can fine-tune LMs to align with human preferences as well as or better than existing methods. Notably, fine-tuning with DPO exceeds PPO-based RLHF in ability to control sentiment of generations, and matches or improves response quality in summarization and single-turn dialogue while being substantially simpler to implement and train. + +The first step is to train an SFT model, to ensure the data we train on is in-distribution for the DPO algorithm. + +Then, fine-tuning a language model via DPO consists of two steps and is easier than [PPO](ppo_trainer): + +1. **Data collection**: Gather a [preference dataset](dataset_formats#preference) with positive and negative selected pairs of generation, given a prompt. +2. **Optimization**: Maximize the log-likelihood of the DPO loss directly. + +This process is illustrated in the sketch below (from [Figure 1 of the DPO paper](https://huggingface.co/papers/2305.18290)): + +![](https://github.com/huggingface/trl/assets/49240599/9150fac6-3d88-4ca2-8ec6-2a6f3473216d) + +Read more about DPO algorithm in the [original paper](https://huggingface.co/papers/2305.18290). + +## Quick start + +This example demonstrates how to train a model using the DPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here: + + + +Below is the script to train the model: + +```python +# train_dpo.py +from datasets import load_dataset +from trl import DPOConfig, DPOTrainer +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") + +training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", logging_steps=10) +trainer = DPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset) +trainer.train() +``` + +Execute the script using the following command: + +```bash +accelerate launch train_dpo.py +``` + +Distributed across 8 GPUs, the training takes approximately 3 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time. + +![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/dpo-qwen2-reward-margin.png) + +To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-DPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models). + +
$ transformers chat trl-lib/Qwen2-0.5B-DPO
+<shirin_yamani>:
+What is Huggingface?
+
+<trl-lib/Qwen2-0.5B-DPO>:
+Huggingface is a platform that allows users to access a variety of open-source machine learning resources such as pre-trained models and datasets Huggingface is a platform that allows users to access a variety of open-source machine learning resources such as pre-trained models and datasets for the development of machine learning models and applications. It provides a repository of over 300, 000 pre-trained models in  Huggingface is a platform that allows users to access a variety of open-source machine learning resources such as pre-trained models and datasets for the development of machine learning models and applications. It provides a repository of over 300, 000  pre-trained models in a variety of languages, enabling users to explore and utilize the latest techniques and technologies in the field of machine learning.
+
+ +## Expected dataset type + +DPO requires a [preference dataset](dataset_formats#preference). The [`DPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. + +Although the [`DPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section. + +### Special considerations for vision-language models + +The [`DPOTrainer`] supports fine-tuning vision-language models (VLMs). For these models, a vision dataset is required. To learn more about the specific format for vision datasets, refer to the [Vision dataset format](dataset_formats#vision-datasets) section. + +Additionally, unlike standard text-based models where a `tokenizer` is used, for VLMs, you should replace the `tokenizer` with a `processor`. + +```diff +- model = AutoModelForCausalLM.from_pretrained(model_id) ++ model = AutoModelForVision2Seq.from_pretrained(model_id) + +- tokenizer = AutoTokenizer.from_pretrained(model_id) ++ processor = AutoProcessor.from_pretrained(model_id) + + trainer = DPOTrainer( + model, + args=training_args, + train_dataset=train_dataset, +- processing_class=tokenizer, ++ processing_class=processor, +) +``` + +For a complete example of fine-tuning a vision-language model, refer to the script in [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py). + + +## Example script + +We provide an example script to train a model using the DPO method. The script is available in [`trl/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py) + +To test the DPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command: + +```bash +accelerate launch trl/scripts/dpo.py \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --num_train_epochs 1 \ + --logging_steps 25 \ + --output_dir Qwen2-0.5B-DPO +``` + +## Logged metrics + +While training and evaluating we record the following reward metrics: + +- `rewards/chosen`: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses scaled by beta +- `rewards/rejected`: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses scaled by beta +- `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards +- `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards + +## Loss functions + +The DPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`DPOConfig`]. The following loss functions are supported: + +| `loss_type=` | Description | +| -------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model and in fact the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. | +| `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. | +| `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike DPO which is summed only). | +| `"exo_pair"` | The [EXO](https://huggingface.co/papers/2402.00856) authors propose to minimize the reverse KL instead of the negative log-sigmoid loss of DPO which corresponds to forward KL. Setting non-zero `label_smoothing` (default `1e-3`) leads to a simplified version of EXO on pair-wise preferences (see Eqn. (16) of the [EXO paper](https://huggingface.co/papers/2402.00856)). The full version of EXO uses `K>2` completions generated by the SFT policy, which becomes an unbiased estimator of the PPO objective (up to a constant) when `K` is sufficiently large. | +| `"nca_pair"` | The [NCA](https://huggingface.co/papers/2402.05369) authors shows that NCA optimizes the absolute likelihood for each response rather than the relative likelihood. | +| `"robust"` | The [Robust DPO](https://huggingface.co/papers/2403.00409) authors propose an unbiased estimate of the DPO loss that is robust to preference noise in the data. Like in cDPO, it assumes that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0) | +| `"bco_pair"` | The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. For unpaired data, we recommend the dedicated [`BCOTrainer`]. | +| `"sppo_hard"` | The [SPPO](https://huggingface.co/papers/2405.00675) authors claim that SPPO is capable of solving the Nash equilibrium iteratively by pushing the chosen rewards to be as large as 1/2 and the rejected rewards to be as small as -1/2 and can alleviate data sparsity issues. The implementation approximates this algorithm by employing hard label probabilities, assigning 1 to the winner and 0 to the loser. | +| `"aot"` or `loss_type="aot_pair"` | The [AOT](https://huggingface.co/papers/2406.05882) authors propose to use Distributional Preference Alignment Via Optimal Transport. Traditionally, the alignment algorithms use paired preferences at a sample level, which does not ensure alignment on the distributional level. AOT, on the other hand, can align LLMs on paired or unpaired preference data by making the reward distribution of the positive samples stochastically dominant in the first order on the distribution of negative samples. Specifically, `loss_type="aot"` is appropriate for paired datasets, where each prompt has both chosen and rejected responses; `loss_type="aot_pair"` is for unpaired datasets. In a nutshell, `loss_type="aot"` ensures that the log-likelihood ratio of chosen to rejected of the aligned model has higher quantiles than that ratio for the reference model. `loss_type="aot_pair"` ensures that the chosen reward is higher on all quantiles than the rejected reward. Note that in both cases quantiles are obtained via sorting. To fully leverage the advantages of the AOT algorithm, it is important to maximize the per-GPU batch size. | +| `"apo_zero"` or `loss_type="apo_down"` | The [APO](https://huggingface.co/papers/2408.06266) method introduces an "anchored" version of the alignment objective. There are two variants: `apo_zero` and `apo_down`. The `apo_zero` loss increases the likelihood of winning outputs while decreasing the likelihood of losing outputs, making it suitable when the model is less performant than the winning outputs. On the other hand, `apo_down` decreases the likelihood of both winning and losing outputs, but with a stronger emphasis on reducing the likelihood of losing outputs. This variant is more effective when the model is better than the winning outputs. | +| `"discopop"` | The [DiscoPOP](https://huggingface.co/papers/2406.08414) paper uses LLMs to discover more efficient offline preference optimization losses. In the paper the proposed DiscoPOP loss (which is a log-ratio modulated loss) outperformed other optimization losses on different tasks (IMDb positive text generation, Reddit TLDR summarization, and Alpaca Eval 2.0). | + +### Label smoothing + +The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0). + +### Syncing the reference model + +The [TR-DPO](https://huggingface.co/papers/2404.09656) paper suggests syncing the reference model weights after every `ref_model_sync_steps` steps of SGD with weight `ref_model_mixup_alpha` during DPO training. To toggle this callback use the `sync_ref_model=True` in the [`DPOConfig`]. + +### RPO loss + +The [RPO](https://huggingface.co/papers/2404.19733) paper implements an iterative preference tuning algorithm using a loss related to the RPO loss in this [paper](https://huggingface.co/papers/2405.16436) that essentially consists of a weighted SFT loss on the chosen preferences together with the DPO loss. To use this loss, set the `rpo_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this weight to `1.0`. + +### WPO loss + +The [WPO](https://huggingface.co/papers/2406.11827) paper adapts off-policy data to resemble on-policy data more closely by reweighting preference pairs according to their probability under the current policy. To use this method, set the `use_weighting` flag to `True` in the [`DPOConfig`]. + +### LD-DPO loss + +The [LD-DPO](https://huggingface.co/papers/2409.06411) paper decomposes the portion of the response that exceeds the desired length into two components — human-like preferences and verbosity preference — based on a mixing coefficient \\( \alpha \\). To use this method, set the `ld_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this value between `0.0` and `1.0`. + +### For Mixture of Experts Models: Enabling the auxiliary loss + +MOEs are the most efficient if the load is about equally distributed between experts. +To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss. + +This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]). +To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config. + +## Accelerate DPO fine-tuning using `unsloth` + +You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek etc) and Mistral architectures. Some benchmarks for DPO listed below: + +| GPU | Model | Dataset | 🤗 | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved | +| -------- | --------- | ---------- | --- | --------------------- | --------- | ------------ | +| A100 40G | Zephyr 7b | Ultra Chat | 1x | 1.24x | **1.88x** | -11.6% | +| Tesla T4 | Zephyr 7b | Ultra Chat | 1x | 1.09x | **1.55x** | -18.6% | + +First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows: + +```diff + from datasets import load_dataset + from trl import DPOConfig, DPOTrainer +- from transformers import AutoModelForCausalLM, AutoTokenizer ++ from unsloth import FastLanguageModel + +- model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +- tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") ++ model, tokenizer = FastLanguageModel.from_pretrained("Qwen/Qwen2-0.5B-Instruct") ++ model = FastLanguageModel.get_peft_model(model) + train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") + +- training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", logging_steps=10) ++ training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", logging_steps=10, bf16=True) + trainer = DPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset) + trainer.train() + +``` + +The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth). + +## Reference model considerations with PEFT + +You have three main options (plus several variants) for how the reference model works when using PEFT, assuming the model that you would like to further enhance with DPO was tuned using (Q)LoRA. + +1. Simply create two instances of the model, each loading your adapter - works fine but is very inefficient. +2. Merge the adapter into the base model, create another adapter on top, then leave the `ref_model` param null, in which case DPOTrainer will unload the adapter for reference inference - efficient, but has potential downsides discussed below. +3. Load the adapter twice with different names, then use `set_adapter` during training to swap between the adapter being DPO'd and the reference adapter - slightly less efficient compared to 2 (~adapter size VRAM overhead), but avoids the pitfalls. + +### Downsides to merging QLoRA before DPO (approach 2) + +As suggested by [Benjamin Marie](https://medium.com/@bnjmn_marie/dont-merge-your-lora-adapter-into-a-4-bit-llm-65b6da287997), the best option for merging QLoRA adapters is to first dequantize the base model, then merge the adapter. Something similar to [this script](https://github.com/jondurbin/qlora/blob/main/qmerge.py). + +However, after using this approach, you will have an unquantized base model. Therefore, to use QLoRA for DPO, you will need to re-quantize the merged model or use the unquantized merge (resulting in higher memory demand). + +### Using option 3 - load the adapter twice + +To avoid the downsides with option 2, you can load your fine-tuned adapter into the model twice, with different names, and set the model/ref adapter names in [`DPOTrainer`]. + +For example: + +```python +# Load the base model. +bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", +) +model = AutoModelForCausalLM.from_pretrained( + "mistralai/mixtral-8x7b-v0.1", + load_in_4bit=True, + quantization_config=bnb_config, + attn_implementation="flash_attention_2", + torch_dtype=torch.bfloat16, + device_map="auto", +) +model.config.use_cache = False + +# Load the adapter. +model = PeftModel.from_pretrained( + model, + "/path/to/peft", + is_trainable=True, + adapter_name="train", +) +# Load the adapter a second time, with a different name, which will be our reference model. +model.load_adapter("/path/to/peft", adapter_name="reference") + +# Initialize the trainer, without a ref_model param. +training_args = DPOConfig( + model_adapter_name="train", + ref_adapter_name="reference", +) +dpo_trainer = DPOTrainer( + model, + args=training_args, + ... +) +``` + +## DPOTrainer + +[[autodoc]] DPOTrainer + +## DPOConfig + +[[autodoc]] DPOConfig + +## DataCollatorForPreference + +[[autodoc]] trainer.dpo_trainer.DataCollatorForPreference diff --git a/docs/source/example_overview.md b/docs/source/example_overview.md new file mode 100644 index 0000000000000000000000000000000000000000..779b9160e56ee5b6e3a96b43295c22f2c9948e85 --- /dev/null +++ b/docs/source/example_overview.md @@ -0,0 +1,89 @@ +# Examples + + +## Introduction + +The examples should work in any of the following settings (with the same script): + - single GPU + - multi GPUS (using PyTorch distributed mode) + - multi GPUS (using DeepSpeed ZeRO-Offload stages 1, 2, & 3) + - fp16 (mixed-precision), fp32 (normal precision), or bf16 (bfloat16 precision) + +To run it in each of these various modes, first initialize the accelerate +configuration with `accelerate config` + +**NOTE to train with a 4-bit or 8-bit model**, please run + +```bash +pip install --upgrade trl[quantization] +``` + + +## Accelerate Config +For all the examples, you'll need to generate a 🤗 Accelerate config file with: + +```shell +accelerate config # will prompt you to define the training configuration +``` + +Then, it is encouraged to launch jobs with `accelerate launch`! + + +# Maintained Examples + +Scripts can be used as examples of how to use TRL trainers. They are located in the [`trl/scripts`](https://github.com/huggingface/trl/blob/main/trl/scripts) directory. Additionally, we provide examples in the [`examples/scripts`](https://github.com/huggingface/trl/blob/main/examples/scripts) directory. These examples are maintained and tested regularly. + +| File | Description | +| --- | --- | +| [`examples/scripts/alignprop.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/alignprop.py) | This script shows how to use the [`AlignPropTrainer`] to fine-tune a diffusion model. | +| [`examples/scripts/bco.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/bco.py) | This script shows how to use the [`KTOTrainer`] with the BCO loss to fine-tune a model to increase instruction-following, truthfulness, honesty and helpfulness using the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. | +| [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. | +| [`examples/scripts/ddpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ddpo.py) | This script shows how to use the [`DDPOTrainer`] to fine-tune a stable diffusion model using reinforcement learning. | +| [`examples/scripts/dpo_online.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_online.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a model. | +| [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a Vision Language Model to reduce hallucinations using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) dataset. | +| [`examples/scripts/gkd.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gkd.py) | This script shows how to use the [`GKDTrainer`] to fine-tune a model. | +| [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) | This script shows how to use the [`NashMDTrainer`] to fine-tune a model. | +| [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. | +| [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language | +| [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. | +| [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py) | This script shows how to use the [`PRMTrainer`] to fine-tune a Process-supervised Reward Model (PRM). | +| [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train a Outcome Reward Model (ORM) on your own dataset. | +| [`examples/scripts/rloo/rloo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo/rloo.py) | This script shows how to use the [`RLOOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language | +| [`examples/scripts/rloo/rloo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo/rloo_tldr.py) | This script shows how to use the [`RLOOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. | +| [`examples/scripts/sft_gemma3.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_gemma3.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Gemma 3 model. | +| [`examples/scripts/sft_video_llm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_video_llm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Video Language Model. | +| [`examples/scripts/sft_vlm_gemma3.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_gemma3.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Gemma 3 model on vision to text tasks. | +| [`examples/scripts/sft_vlm_smol_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_smol_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a SmolVLM model. | +| [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Vision Language Model in a chat setting. The script has only been tested with [LLaVA 1.5](https://huggingface.co/llava-hf/llava-1.5-7b-hf), [LLaVA 1.6](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf), and [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) models so users may see unexpected behaviour in other model architectures. | +| [`examples/scripts/xpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/xpo.py) | This script shows how to use the [`XPOTrainer`] to fine-tune a model. | + +Here are also some easier-to-run colab notebooks that you can use to get started with TRL: + +| File | Description | +| --- | --- | +| [`examples/notebooks/best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb) | This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO. | +| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. | +| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. | + + +We also have some other examples that are less maintained but can be used as a reference: +1. **[research_projects](https://github.com/huggingface/trl/tree/main/examples/research_projects)**: Check out this folder to find the scripts used for some research projects that used TRL (LM de-toxification, Stack-Llama, etc.) + + +## Distributed training + +All of the scripts can be run on multiple GPUs by providing the path of an 🤗 Accelerate config file when calling `accelerate launch`. To launch one of them on one or multiple GPUs, run the following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine and `--all_arguments_of_the_script` with your arguments.) + +```shell +accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script +``` + +You can also adjust the parameters of the 🤗 Accelerate config file to suit your needs (e.g. training in mixed precision). + +### Distributed training with DeepSpeed + +Most of the scripts can be run on multiple GPUs together with DeepSpeed ZeRO-{1,2,3} for efficient sharding of the optimizer states, gradients, and model weights. To do so, run following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine, `--all_arguments_of_the_script` with your arguments, and `--deepspeed_config` with the path to the DeepSpeed config file such as `examples/deepspeed_configs/deepspeed_zero1.yaml`): + +```shell +accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{1,2,3}.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script +``` diff --git a/docs/source/gkd_trainer.md b/docs/source/gkd_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..c4f82ff1605df5217b015ee77003a309d7a3a1c4 --- /dev/null +++ b/docs/source/gkd_trainer.md @@ -0,0 +1,98 @@ +# Generalized Knowledge Distillation Trainer + +[![](https://img.shields.io/badge/All_models-GKD-blue)](https://huggingface.co/models?other=gkd,trl) + +## Overview + +Generalized Knowledge Distillation (GKD) was proposed in [On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes](https://huggingface.co/papers/2306.13649) by Rishabh Agarwal, Nino Vieillard, Yongchao Zhou, Piotr Stanczyk, Sabela Ramos, Matthieu Geist, and Olivier Bachem. + +The abstract from the paper is the following: + +> Knowledge distillation (KD) is widely used for compressing a teacher model to reduce its inference cost and memory footprint, by training a smaller student model. However, current KD methods for auto-regressive sequence models suffer from distribution mismatch between output sequences seen during training and those generated by the student during inference. To address this issue, we introduce Generalized Knowledge Distillation (GKD). Instead of solely relying on a fixed set of output sequences, GKD trains the student on its self-generated output sequences by leveraging feedback from the teacher on such sequences. Unlike supervised KD approaches, GKD also offers the flexibility to employ alternative loss functions between the student and teacher, which can be useful when the student lacks the expressivity to mimic the teacher's distribution. Furthermore, GKD facilitates the seamless integration of distillation with RL fine-tuning (RLHF). We demonstrate the efficacy of GKD for distilling auto-regressive language models on summarization, translation, and arithmetic reasoning tasks, and task-agnostic distillation for instruction-tuning. + + +The key aspects of GKD are: +1. It addresses the train-inference distribution mismatch in auto-regressive sequence models by training the student model on its self-generated output sequences. +2. GKD allows flexibility in choosing different divergence measures between student and teacher models via the generalized Jensen-Shannon Divergence (JSD), which can be useful when the student lacks the capacity to fully mimic the teacher. + +This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif) and [Lewis Tunstall](https://huggingface.co/lewtun). + +## Usage tips + +The [`GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It needs three parameters to be set via the [`GKDConfig`] namely: +* `lmbda`: controls the student data fraction, i.e., the proportion of on-policy student-generated outputs. When `lmbda=0.0`, the loss reduces to supervised JSD where the student is trained with the token-level probabilities of the teacher. When `lmbda=1.0`, the loss reduces to on-policy JSD, where the student generates output sequences and token-specific feedback on these sequences from the teacher. For values in between [0, 1] it is random between the two based on the `lmbda` value for each batch. +* `seq_kd`: controls whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated out). When `seq_kd=True` and `lmbda=0.0`, the loss reduces to supervised JSD, where the teacher generates output sequences and the student receives token-specific feedback on these sequences from the teacher. +* `beta`: controls the interpolation in the generalized Jensen-Shannon Divergence. When `beta=0.0` the loss approximates forward KL divergence, while for `beta=1.0` the loss approximates reverse KL divergence. For values in between [0, 1] it interpolates between the two. + +The authors find that on-policy data (high `lmbda`) performs better and the optimal `beta` varied depending on the task and evaluation method. + +> [!WARNING] +> Make sure that `attn_implementation="flash_attention_2"` when training [Gemma models](https://huggingface.co/models?other=gemma2). Otherwise you will encounter NaNs in the logits due to the [soft capping technique](https://huggingface.co/blog/gemma2#soft-capping-and-attention-implementations) adopted by this architecture. + +The basic API is as follows: + +```python +from datasets import Dataset +from trl import GKDConfig, GKDTrainer +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, +) + +NUM_DUMMY_SAMPLES = 100 + +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +# The model to optimise +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +# The teacher model to calculate the KL divergence against +teacher_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct") + +train_dataset = Dataset.from_dict( + { + "messages": [ + [ + {"role": "user", "content": "Hi, how are you?"}, + {"role": "assistant", "content": "I'm great thanks"}, + ] + ] + * NUM_DUMMY_SAMPLES + } +) +eval_dataset = Dataset.from_dict( + { + "messages": [ + [ + {"role": "user", "content": "What colour is the sky?"}, + {"role": "assistant", "content": "The sky is blue"}, + ] + ] + * NUM_DUMMY_SAMPLES + } +) + +training_args = GKDConfig(output_dir="gkd-model", per_device_train_batch_size=1) +trainer = GKDTrainer( + model=model, + teacher_model=teacher_model, + args=training_args, + processing_class=tokenizer, + train_dataset=train_dataset, + eval_dataset=eval_dataset, +) +trainer.train() +``` + +### Expected dataset type + +The dataset should be formatted as a list of "messages" where each message is a list of dictionaries with the following keys: +* `role`: either `system`, `assistant` or `user` +* `content`: the message content + + +## GKDTrainer + +[[autodoc]] GKDTrainer + +## GKDConfig + +[[autodoc]] GKDConfig diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..5f0f9386b0528e52229731ba319479d16f63454f --- /dev/null +++ b/docs/source/grpo_trainer.md @@ -0,0 +1,518 @@ +# GRPO Trainer + +[![](https://img.shields.io/badge/All_models-GRPO-blue)](https://huggingface.co/models?other=grpo,trl) + +## Overview + +TRL supports the GRPO Trainer for training language models, as described in the paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300) by [Zhihong Shao](https://huggingface.co/syhia), [Peiyi Wang](https://huggingface.co/peiyiwang89), [Qihao Zhu](https://huggingface.co/zqh11), Runxin Xu, [Junxiao Song](https://huggingface.co/haha-point), Mingchuan Zhang, Y. K. Li, Y. Wu, [Daya Guo](https://huggingface.co/guoday). + +The abstract from the paper is the following: + +> Mathematical reasoning poses a significant challenge for language models due to its complex and structured nature. In this paper, we introduce DeepSeekMath 7B, which continues pre-training DeepSeek-Coder-Base-v1.5 7B with 120B math-related tokens sourced from Common Crawl, together with natural language and code data. DeepSeekMath 7B has achieved an impressive score of 51.7% on the competition-level MATH benchmark without relying on external toolkits and voting techniques, approaching the performance level of Gemini-Ultra and GPT-4. Self-consistency over 64 samples from DeepSeekMath 7B achieves 60.9% on MATH. The mathematical reasoning capability of DeepSeekMath is attributed to two key factors: First, we harness the significant potential of publicly available web data through a meticulously engineered data selection pipeline. Second, we introduce Group Relative Policy Optimization (GRPO), a variant of Proximal Policy Optimization (PPO), that enhances mathematical reasoning abilities while concurrently optimizing the memory usage of PPO. + +This post-training method was contributed by [Quentin Gallouédec](https://huggingface.co/qgallouedec). + +## Quick start + +This example demonstrates how to train a model using the GRPO method. We train a [Qwen 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) with the prompts from the [TLDR dataset](https://huggingface.co/datasets/trl-lib/tldr) (completion column is ignored!). You can view the data in the dataset here: + + + +Below is the script to train the model. + +```python +# train_grpo.py +from datasets import load_dataset +from trl import GRPOConfig, GRPOTrainer + +dataset = load_dataset("trl-lib/tldr", split="train") + +# Define the reward function, which rewards completions that are close to 20 characters +def reward_len(completions, **kwargs): + return [-abs(20 - len(completion)) for completion in completions] + +training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=10) +trainer = GRPOTrainer( + model="Qwen/Qwen2-0.5B-Instruct", + reward_funcs=reward_len, + args=training_args, + train_dataset=dataset, +) +trainer.train() +``` + +Execute the script using the following command: + +```bash +accelerate launch train_grpo.py +``` + +Distributed across 8 GPUs, the training takes approximately 1 day. + +![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/grpo_curves.png) + +## Looking deeper into the GRPO method + +GRPO is an online learning algorithm, meaning it improves iteratively by using the data generated by the trained model itself during training. The intuition behind GRPO objective is to maximize the advantage of the generated completions, while ensuring that the model remains close to the reference policy. To understand how GRPO works, it can be broken down into four main steps: **Generating completions**, **computing the advantage**, **estimating the KL divergence**, and **computing the loss**. + +![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/grpo_visual.png) + +### Generating completions + +At each training step, we sample a batch of prompts and generate a set of \\( G \\) completions for each prompt (denoted as \\( o_i \\)). + +### Computing the advantage + +For each of the \\( G \\) sequences, we compute the reward using a reward model. To align with the comparative nature of reward models—typically trained on datasets of comparisons between outputs for the same question—the advantage is calculated to reflect these relative comparisons. It is normalized as follows: + +$$\hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}$$ + +This approach gives the method its name: **Group Relative Policy Optimization (GRPO)**. + + + +It was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that scaling by \\( \text{std}(\mathbf{r}) \\) may cause a question-level difficulty bias. You can disable this scaling by setting `scale_rewards=False` in [`GRPOConfig`]. + + + +### Estimating the KL divergence + +KL divergence is estimated using the approximator introduced by [Schulman et al. (2020)](http://joschu.net/blog/kl-approx.html). The approximator is defined as follows: + +$$\mathbb{D}_{\text{KL}}\left[\pi_\theta \|\pi_{\text{ref}}\right] = \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i, + +Note that compared to the original formulation in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300), we don't scale by \\( \frac{1}{|o_i|} \\) because it was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that this introduces a response-level length bias. More details in [loss types](#loss-types). + + + + + +Note that compared to the original formulation in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300), we use \\( \beta = 0.0 \\) by default, meaning that the KL divergence term is not used. This choice is motivated by several recent studies (e.g., [Open-Reasoner-Zero: An Open Source Approach to Scaling Up Reinforcement Learning on the Base Model](https://huggingface.co/papers/2503.24290)) which have shown that the KL divergence term is not essential for training with GRPO. As a result, it has become common practice to exclude it (e.g. [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783), [DAPO: An Open-Source LLM Reinforcement Learning System at Scale](https://huggingface.co/papers/2503.14476)). If you wish to include the KL divergence term, you can set `beta` in [`GRPOConfig`] to a non-zero value. + + + +In the original paper, this formulation is generalized to account for multiple updates after each generation (denoted \\( \mu \\), can be set with `num_iterations` in [`GRPOConfig`]) by leveraging the **clipped surrogate objective**: + +$$ +\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \left[ \min \left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})} \hat{A}_{i,t}, \, \text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right], +$$ + +where \\(\text{clip}(\cdot, 1 - \epsilon, 1 + \epsilon) \\) ensures that updates do not deviate excessively from the reference policy by bounding the policy ratio between \\( 1 - \epsilon \\) and \\( 1 + \epsilon \\). +When \\( \mu = 1 \\) (default in TRL), the clipped surrogate objective simplifies to the original objective. + +#### Loss Types + +Several formulations of the objective have been proposed in the literature. Initially, the objective of GRPO was defined as follows: + +$$ +\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} l_{i,t}, +$$ + +where + +$$ +l_{i,t} = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right]. +$$ + +The [DAPO paper](https://huggingface.co/papers/2503.14476) highlights the limitations of the GRPO algorithm’s sample-level loss in long-CoT scenarios, where longer responses are under-penalized, leading to poorer quality outputs. The proposed solution is a token-level normalization, which better handles longer sequences by assigning more balanced rewards to individual tokens, regardless of response length: + +$$ +\mathcal{L}_{\text{DAPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} l_{i,t}, +$$ + + +Furthermore, it was demonstrated in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that the initial GRPO formulation introduces a response length bias. They show that while the DAPO formulation reduces this bias, it does not eliminate it completely. To fully remove this bias, they propose dividing by a constant instead of the sequence length, resulting in the following formulation: + +$$ +\mathcal{L}_{\text{Dr. GRPO}}(\theta) = - \frac{1}{LG} \sum_{i=1}^G \sum_{t=1}^{|o_i|} l_{i,t}, +$$ + +This constant is recommended to be the maximum completion length. To use this formulation, set `loss_type="dr_grpo"` in the [`GRPOConfig`]. + +## Logged metrics + +- `num_tokens`: The total number of tokens processed so far, including both prompts and completions. +- `completions/mean_length`: The average length of generated completions. +- `completions/min_length`: The minimun length of generated completions. +- `completions/max_length`: The maximum length of generated completions. +- `completions/mean_terminated_length`: The average length of generated completions that terminate with EOS. +- `completions/min_terminated_length`: The minimun length of generated completions that terminate with EOS. +- `completions/max_terminated_length`: The maximum length of generated completions that terminate with EOS. +- `completions/clipped_ratio` : The ratio of truncated (clipped) completions. +- `reward/{reward_func_name}/mean`: The average reward from a specific reward function. +- `reward/{reward_func_name}/std`: The standard deviation of the reward from a specific reward function. +- `reward`: The overall average reward after applying reward weights. +- `reward_std`: The standard deviation of the overall reward within each batch after applying reward weights. +- `frac_reward_zero_std`: The fraction of samples in the generation batch with a reward std of zero, implying there is little diversity for that prompt (all answers are correct or incorrect). +- `kl`: The average KL divergence between the model and the reference model, calculated over generated completions. Logged only if `beta` is nonzero. +- `clip_ratio/region_mean`: The ratio of token probabilities where the GRPO objective is clipped to stay within the trust region: +$$ +\text{clip}\left( r_{i,t}(\theta), 1 - \epsilon_\mathrm{low}, 1 + \epsilon_\mathrm{high} \right)\,, \qquad r_{i,t}(\theta) = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}\,. +$$ +A higher value means more tokens are clipped, which constrains how much the policy $\pi_\theta$ can change. +- `clip_ratio/low_mean`: The average ratio of token probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\) +- `clip_ratio/low_min`: The minimum ratio of token probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\) +- `clip_ratio/high_mean`: The average ratio of token probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\) +- `clip_ratio/high_max`: The maximum ratio of token probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\). + +## Customization + +### Speed up training with vLLM-powered generation + +Generation is often the main bottleneck when training with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a high-throughput, low-latency inference engine for LLMs. To enable it, first install the package with +```shell +pip install trl[vllm] +``` + +We support two ways of using vLLM during training: **server mode** and **colocate mode**. + +#### 🔌 Option 1: Server mode + +In this mode, vLLM runs in a separate process (and using separate GPUs) and communicates with the trainer via HTTP. This is ideal if you have dedicated GPUs for inference. + +1. **Start the vLLM server**: + ```bash + trl vllm-serve --model + ``` + +2. **Enable server mode in your training script**: + ```python + from trl import GRPOConfig + + training_args = GRPOConfig( + ..., + use_vllm=True, + vllm_mode="server", # default value, can be omitted + ) + ``` + + + +Make sure that the server is using different GPUs than the trainer, otherwise you may run into NCCL errors. You can specify the GPUs to use with the `CUDA_VISIBLE_DEVICES` environment variable. + + + +#### 🧩 Option 2: Colocate mode + +In this mode, vLLM runs inside the trainer process and shares GPU memory with the training model. This avoids launching a separate server and can improve GPU utilization, but may lead to memory contention on the training GPUs. + +```python +from trl import GRPOConfig + +training_args = GRPOConfig( + ..., + use_vllm=True, + vllm_mode="colocate", +) +``` + + + +Depending on the model size and the overall GPU memory requirements for training, you may need to adjust the `vllm_gpu_memory_utilization` parameter in [`GRPOConfig`] to avoid underutilization or out-of-memory errors. + + + +For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods). + +### GRPO at scale: train a 70B+ Model on multiple nodes + +When training large models like **Qwen2.5-72B**, you need several key optimizations to make the training efficient and scalable across multiple GPUs and nodes. These include: + +- **DeepSpeed ZeRO Stage 3**: ZeRO leverages data parallelism to distribute model states (weights, gradients, optimizer states) across multiple GPUs and CPUs, reducing memory and compute requirements on each device. Since large models cannot fit on a single GPU, using ZeRO Stage 3 is required for training such model. For more details, see [DeepSpeed Integration](deepspeed_integration). +- **Accelerate**: Accelerate is a library that simplifies distributed training across multiple GPUs and nodes. It provides a simple API to launch distributed training and handles the complexities of distributed training, such as data parallelism, gradient accumulation, and distributed data loading. For more details, see [Distributing Training](distributing_training). +- **vLLM**: See the previous section on how to use vLLM to speed up generation. + +Below is an example SLURM script to train a 70B model with GRPO on multiple nodes. This script trains a model on 4 nodes and uses the 5th node for vLLM-powered generation. + +```sh +#!/bin/bash +#SBATCH --nodes=5 +#SBATCH --gres=gpu:8 + +# Get the list of allocated nodes +NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST)) + +# Assign the first 4 nodes for training and the 5th node for vLLM +TRAIN_NODES="${NODELIST[@]:0:4}" # Nodes 0, 1, 2, 3 for training +VLLM_NODE="${NODELIST[4]}" # Node 4 for vLLM + +# Run training on the first 4 nodes (Group 1) +srun --nodes=4 --ntasks=4 --nodelist="${NODELIST[@]:0:4}" accelerate launch \ + --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ + --num_processes 32 \ + --num_machines 4 \ + --main_process_ip ${NODELIST[0]} \ + --machine_rank $SLURM_PROCID \ + --rdzv_backend c10d \ + train_grpo.py \ + --server_ip $VLLM_NODE & + +# Run vLLM server on the 5th node (Group 2) +srun --nodes=1 --ntasks=1 --nodelist="${NODELIST[4]}" trl vllm-serve --model Qwen/Qwen2.5-72B --tensor_parallel_size 8 & + +wait +``` + +```python +import argparse + +from datasets import load_dataset +from trl import GRPOTrainer, GRPOConfig + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--vllm_server_host", type=str, default="", help="The server IP") + args = parser.parse_args() + + # Example dataset from TLDR + dataset = load_dataset("trl-lib/tldr", split="train") + + # Dummy reward function: count the number of unique characters in the completions + def reward_num_unique_chars(completions, **kwargs): + return [len(set(c)) for c in completions] + + training_args = GRPOConfig( + output_dir="Qwen2.5-72B-GRPO", + per_device_train_batch_size=4, + bf16=True, + gradient_checkpointing=True, + logging_steps=10, + use_vllm=True, + vllm_server_host=args.vllm_server_host.replace("ip-", "").replace("-", "."), # from ip-X-X-X-X to X.X.X.X + ) + + trainer = GRPOTrainer(model="Qwen/Qwen2.5-72B", args=training_args, reward_funcs=reward_num_unique_chars, train_dataset=dataset) + trainer.train() + +if __name__=="__main__": + main() +``` + +### Using a custom reward function + +The [`GRPOTrainer`] supports using custom reward functions instead of dense reward models. To ensure compatibility, your reward function must satisfy the following requirements: + +1. **Input arguments**: + - The function must accept the following as keyword arguments: + - `prompts` (contains the prompts), + - `completions` (contains the generated completions), + - `completions_ids` (contains the tokenized completions), + - All columns names (but `prompt`) that the dataset may have. For example, if the dataset contains a column named `ground_truth`, the function will be called with `ground_truth` as a keyword argument. + + The easiest way to comply with this requirement is to use `**kwargs` in the function signature. + - Depending on the dataset format, the input will vary: + - For [standard format](dataset_formats#standard), `prompts` and `completions` will be lists of strings. + - For [conversational format](dataset_formats#conversational), `prompts` and `completions` will be lists of message dictionaries. + +2. **Return value**: The function must return a list of floats. Each float represents the reward corresponding to a single completion. + +#### Example 1: Reward longer completions + +Below is an example of a reward function for a standard format that rewards longer completions: + +```python +def reward_func(completions_ids, **kwargs): + """Reward function that assigns higher scores to longer completions (in terms of token count).""" + return [float(len(ids)) for ids in completions_ids] +``` + +You can test it as follows: + +```python +>>> prompts = ["The sky is", "The sun is"] # not used in the reward function, but the trainer will pass it +>>> completions = [" blue.", " in the sky."] # not used in the reward function, but the trainer will pass it +>>> completions_ids = [[6303, 13], [304, 279, 12884, 13]] +>>> reward_func(prompts=prompts, completions=completions, completions_ids=completions_ids) +[2.0, 4.0] +``` + +#### Example 1.1: Reward longer completions (based in the number of characters) + +Same as the previous example, but this time the reward function is based on the number of characters instead of tokens. + +```python +def reward_func(completions, **kwargs): + """Reward function that assigns higher scores to longer completions (in terms of character count).""" + return [float(len(completion)) for completion in completions] +``` + +You can test it as follows: + +```python +>>> prompts = ["The sky is", "The sun is"] +>>> completions = [" blue.", " in the sky."] +>>> completions_ids = [[6303, 13], [304, 279, 12884, 13]] # not used in the reward function, but the trainer will pass it +>>> reward_func(prompts=prompts, completions=completions, completions_ids=completions_ids) +[6.0, 12.0] +``` + +#### Example 2: Reward completions with specific format + +Below is an example of a reward function that checks if the completion has a specific format. This example is inspired by the _format reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948). +It is designed for conversational format, where prompts and completions consist of structured messages. + +```python +import re + +def format_reward_func(completions, **kwargs): + """Reward function that checks if the completion has a specific format.""" + pattern = r"^.*?.*?$" + completion_contents = [completion[0]["content"] for completion in completions] + matches = [re.match(pattern, content) for content in completion_contents] + return [1.0 if match else 0.0 for match in matches] +``` + +You can test this function as follows: + +```python +>>> prompts = [ +... [{"role": "assistant", "content": "What is the result of (1 + 2) * 4?"}], +... [{"role": "assistant", "content": "What is the result of (3 + 1) * 2?"}], +... ] +>>> completions = [ +... [{"role": "assistant", "content": "The sum of 1 and 2 is 3, which we multiply by 4 to get 12.(1 + 2) * 4 = 12"}], +... [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}], +... ] +>>> format_reward_func(prompts=prompts, completions=completions) +[1.0, 0.0] +``` + +#### Example 3: Reward completions based on a reference + +Below is an example of a reward function that checks if the completion is correct. This example is inspired by the _accuracy reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948). +This example is designed for [standard format](dataset_formats#standard), where the dataset contains a column named `ground_truth`. + +```python +import re + +def reward_func(completions, ground_truth, **kwargs): + # Regular expression to capture content inside \boxed{} + matches = [re.search(r"\\boxed\{(.*?)\}", completion) for completion in completions] + contents = [match.group(1) if match else "" for match in matches] + # Reward 1 if the content is the same as the ground truth, 0 otherwise + return [1.0 if c == gt else 0.0 for c, gt in zip(contents, ground_truth)] +``` + +You can test this function as follows: + +```python +>>> prompts = ["Problem: Solve the equation $2x + 3 = 7$. Solution:", "Problem: Solve the equation $3x - 5 = 10$."] +>>> completions = [r" The solution is \boxed{2}.", r" The solution is \boxed{6}."] +>>> ground_truth = ["2", "5"] +>>> reward_func(prompts=prompts, completions=completions, ground_truth=ground_truth) +[1.0, 0.0] +``` +#### Example 4: Multi-task reward functions + +Below is an example of using multiple reward functions in the [`GRPOTrainer`]. In this example, we define two task-specific reward functions: `math_reward_func` and `coding_reward_func`. The `math_reward_func` rewards math problems based on their correctness, while the `coding_reward_func` rewards coding problems based on whether the solution works. + +```python +from datasets import Dataset +from trl import GRPOTrainer + +# Define a dataset that contains both math and coding problems +dataset = Dataset.from_list( + [ + {"prompt": "What is 2+2?", "task": "math"}, + {"prompt": "Write a function that returns the sum of two numbers.", "task": "code"}, + {"prompt": "What is 3*4?", "task": "math"}, + {"prompt": "Write a function that returns the product of two numbers.", "task": "code"}, + ] +) + +# Math-specific reward function +def math_reward_func(prompts, completions, task, **kwargs): + rewards = [] + for prompt, completion, t in zip(prompts, completions, task): + if t == "math": + # Calculate math-specific reward + correct = check_math_solution(prompt, completion) + reward = 1.0 if correct else -1.0 + rewards.append(reward) + else: + # Return None for non-math tasks + rewards.append(None) + return rewards + +# Coding-specific reward function +def coding_reward_func(prompts, completions, task, **kwargs): + rewards = [] + for prompt, completion, t in zip(prompts, completions, task): + if t == "coding": + # Calculate coding-specific reward + works = test_code_solution(prompt, completion) + reward = 1.0 if works else -1.0 + rewards.append(reward) + else: + # Return None for non-coding tasks + rewards.append(None) + return rewards + +# Use both task-specific reward functions +trainer = GRPOTrainer( + model="Qwen/Qwen2-0.5B-Instruct", + reward_funcs=[math_reward_func, coding_reward_func], + train_dataset=dataset, +) + +trainer.train() +``` + +In this example, the `math_reward_func` and `coding_reward_func` are designed to work with a mixed dataset that contains both math and coding problems. The `task` column in the dataset is used to determine which reward function to apply to each problem. If there is no relevant reward function for a sample in the dataset, the reward function will return `None` and the [`GRPOTrainer`] will continue with the valid functions and tasks. This allows the [`GRPOTrainer`] to handle multiple reward functions with different applicability. + +Note that the [`GRPOTrainer`] will ignore the `None` rewards returned by the reward functions and only consider the rewards returned by the relevant functions. This ensures that the model is trained on the relevant tasks and ignores the tasks for which there is no relevant reward function. + + + +#### Passing the reward function to the trainer + +To use your custom reward function, pass it to the [`GRPOTrainer`] as follows: + +```python +from trl import GRPOTrainer + +trainer = GRPOTrainer( + reward_funcs=reward_func, + ..., +) +``` + +If you have multiple reward functions, you can pass them as a list: + +```python +from trl import GRPOTrainer + +trainer = GRPOTrainer( + reward_funcs=[reward_func1, reward_func2], + ..., +) +``` +and the reward will be computed as the sum of the rewards from each function, or the weighted sum if `reward_weights` is provided in the config. + +Note that [`GRPOTrainer`] supports multiple reward functions of different types. See the parameters documentation for more details. + +## GRPOTrainer + +[[autodoc]] GRPOTrainer + +## GRPOConfig + +[[autodoc]] GRPOConfig diff --git a/docs/source/how_to_train.md b/docs/source/how_to_train.md new file mode 100644 index 0000000000000000000000000000000000000000..6ac55079b7e90aa13651f3eb1b253d824bdb5931 --- /dev/null +++ b/docs/source/how_to_train.md @@ -0,0 +1,65 @@ +# Training FAQ + +## What Metrics Should I Look at? + +When performing classical supervised fine-tuning of language models, the loss (especially the validation loss) serves as a good indicator of the training progress. However, in Reinforcement Learning (RL), the loss becomes less informative about the model's performance, and its value may fluctuate while the actual performance improves. + +To address this, we recommend focusing on two key metrics first: + +**Mean Reward**: The primary goal is to maximize the reward achieved by the model during RL training. +**Objective KL Divergence**: KL divergence (Kullback-Leibler divergence) measures the dissimilarity between two probability distributions. In the context of RL training, we use it to quantify the difference between the current model and a reference model. Ideally, we want to keep the KL divergence between 0 and 10 to ensure the model's generated text remains close to what the reference model produces. + +However, there are more metrics that can be useful for debugging, checkout the [logging section](logging). + +## Why Do We Use a Reference Model, and What's the Purpose of KL Divergence? + +When training RL models, optimizing solely for reward may lead to unexpected behaviors, where the model exploits the environment in ways that don't align with good language generation. In the case of RLHF, we use a reward model trained to predict whether a generated text is highly ranked by humans. + +However, the RL model being optimized against the reward model may learn patterns that yield high reward but do not represent good language. This can result in extreme cases where the model generates texts with excessive exclamation marks or emojis to maximize the reward. In some worst-case scenarios, the model may generate patterns completely unrelated to natural language yet receive high rewards, similar to adversarial attacks. + +
+ +

Figure: Samples without a KL penalty from https://huggingface.co/papers/1909.08593.

+
+ +To address this issue, we add a penalty to the reward function based on the KL divergence between the current model and the reference model. By doing this, we encourage the model to stay close to what the reference model generates. + +## What Is the Concern with Negative KL Divergence? + +If you generate text by purely sampling from the model distribution things work fine in general. But when you use the `generate` method there are a few caveats because it does not always purely sample depending on the settings which can cause KL-divergence to go negative. Essentially when the active model achieves `log_p_token_active < log_p_token_ref` we get negative KL-div. This can happen in a several cases: + +- **top-k sampling**: the model can smooth out the probability distribution causing the top-k tokens having a smaller probability than those of the reference model but they still are selected +- **min_length**: this ignores the EOS token until `min_length` is reached. thus the model can assign a very low log prob to the EOS token and very high probs to all others until min_length is reached + +These are just a few examples. Why is negative KL an issue? The total reward `R` is computed `R = r - beta * KL` so if the model can learn how to drive KL-divergence negative it effectively gets a positive reward. In many cases it can be much easier to exploit such a bug in the generation than actually learning the reward function. In addition the KL can become arbitrarily small thus the actual reward can be very small compared to it. + +So how should you generate text for PPO training? Let's have a look! + +## How to generate text for training? + +In order to avoid the KL issues described above we recommend to use the following settings: + +```python +generation_kwargs = { + "min_length": -1, # don't ignore the EOS token (see above) + "top_k": 0.0, # no top-k sampling + "top_p": 1.0, # no nucleus sampling + "do_sample": True, # yes, we want to sample + "pad_token_id": tokenizer.eos_token_id, # most decoder models don't have a padding token - use EOS token instead + "max_new_tokens": 32, # specify how many tokens you want to generate at most +} +``` + +With these settings we usually don't encounter any issues. You can also experiments with other settings but if you encounter issues with negative KL-divergence try to go back to these and see if they persist. + +## How can debug your own use-case? + +Debugging the RL pipeline can be challenging due to its complexity. Here are some tips and suggestions to make the process easier: + +- **Start from a working example**: Begin with a working example from the trl repository and gradually modify it to fit your specific use-case. Changing everything at once can make it difficult to identify the source of potential issues. For example, you can start by replacing the model in the example and once you figure out the best hyperparameters try to switch to your dataset and reward model. If you change everything at once you won't know where a potential problem comes from. +- **Start small, scale later**: Training large models can be very slow and take several hours or days until you see any improvement. For debugging this is not a convenient timescale so try to use small model variants during the development phase and scale up once that works. That being said you sometimes have to be careful as small models might not have the capacity to solve a complicated task either. +- **Start simple**: Try to start with a minimal example and build complexity from there. Your use-case might require for example a complicated reward function consisting of many different rewards - try to use one signal first and see if you can optimize that and then add more complexity after that. +- **Inspect the generations**: It's always a good idea to inspect what the model is generating. Maybe there is a bug in your post-processing or your prompt. Due to bad settings you might cut-off generations too soon. These things are very hard to see on the metrics but very obvious if you look at the generations. +- **Inspect the reward model**: If you reward is not improving over time maybe there's an issue with the reward model. You can look at extreme cases to see if it does what it should: e.g. in the sentiment case you can check if simple positive and negative examples really get different rewards. And you can look at the distribution of your dataset. Finally, maybe the reward is dominated by the query which the model can't affect so you might need to normalize this (e.g. reward of query+response minus reward of the query). + +These are just a few tips that we find helpful - if you have more useful tricks feel free to open a PR to add them as well! diff --git a/docs/source/index.md b/docs/source/index.md new file mode 100644 index 0000000000000000000000000000000000000000..879defe938f1a053a3e3a64649d17fbe4c3b7b9b --- /dev/null +++ b/docs/source/index.md @@ -0,0 +1,82 @@ +
+ +
+ +# TRL - Transformer Reinforcement Learning + +TRL is a full stack library where we provide a set of tools to train transformer language models with methods like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), Direct Preference Optimization (DPO), Reward Modeling, and more. +The library is integrated with 🤗 [transformers](https://github.com/huggingface/transformers). + +You can also explore TRL-related models, datasets, and demos in the [TRL Hugging Face organization](https://huggingface.co/trl-lib). + +## Learn + +Learn post-training with TRL and other libraries in 🤗 [smol course](https://github.com/huggingface/smol-course). + +## Contents + +The documentation is organized into the following sections: + +- **Getting Started**: installation and quickstart guide. +- **Conceptual Guides**: dataset formats, training FAQ, and understanding logs. +- **How-to Guides**: reducing memory usage, speeding up training, distributing training, etc. +- **Integrations**: DeepSpeed, Liger Kernel, PEFT, etc. +- **Examples**: example overview, community tutorials, etc. +- **API**: trainers, utils, etc. + +## Blog posts + + diff --git a/docs/source/installation.md b/docs/source/installation.md new file mode 100644 index 0000000000000000000000000000000000000000..8ab4165931b13526e28df4c61e948d4591d9094f --- /dev/null +++ b/docs/source/installation.md @@ -0,0 +1,39 @@ +# Installation +You can install TRL either from PyPI or from source: + +## PyPI +Install the library with pip or [uv](https://docs.astral.sh/uv/): + + + + +uv is a fast Rust-based Python package and project manager. Refer to [Installation](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions), . + +```bash +uv pip install trl +``` + + + + +```bash +pip install trl +``` + + + + +## Source +You can also install the latest version from source. First clone the repo and then run the installation with `pip`: + +```bash +git clone https://github.com/huggingface/trl.git +cd trl/ +pip install -e . +``` + +If you want the development install you can replace the pip install with the following: + +```bash +pip install -e ".[dev]" +``` diff --git a/docs/source/iterative_sft_trainer.md b/docs/source/iterative_sft_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..28a90c1790a9cc48cd8bc94ec2891b563ca66947 --- /dev/null +++ b/docs/source/iterative_sft_trainer.md @@ -0,0 +1,139 @@ +# Iterative Trainer + +[![](https://img.shields.io/badge/All_models-Iterative_SFT-blue)](https://huggingface.co/models?other=iterative-sft,trl) + +Iterative fine-tuning is a training method that enables to perform custom actions (generation and filtering for example) between optimization steps. In TRL we provide an easy-to-use API to fine-tune your models in an iterative way in just a few lines of code. + +## Quickstart + +To get started quickly, you can either pass a model identifier or a pre-instantiated model to the trainer: + +```python +from trl import IterativeSFTConfig, IterativeSFTTrainer + +# Using a model identifier +trainer = IterativeSFTTrainer( + "facebook/opt-350m", + args=IterativeSFTConfig( + max_length=512, + output_dir="./output", + ), +) + +# Or using a pre-instantiated model +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m") +tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + +trainer = IterativeSFTTrainer( + model, + args=IterativeSFTConfig( + max_length=512, + output_dir="./output", + ), + processing_class=tokenizer, +) +``` + +## Usage + +The [`IterativeSFTTrainer`] supports two ways of providing input data to the `step` function: + +### Using a list of tensors as input: + +```python +inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, +} + +trainer.step(**inputs) +``` + +### Using a list of strings as input: + +```python +inputs = { + "texts": texts, + "texts_labels": texts_labels, # Optional, defaults to texts +} + +trainer.step(**inputs) +``` + +For causal language models, labels will automatically be created from `input_ids` or from `texts`. When using sequence to sequence models you will have to provide your own labels or `text_labels`. + +## Configuration + +The [`IterativeSFTConfig`] class provides several parameters to customize the training: + +```python +from trl import IterativeSFTConfig + +config = IterativeSFTConfig( + # Model initialization parameters + model_init_kwargs={"torch_dtype": "bfloat16"}, + + # Data preprocessing parameters + max_length=512, + truncation_mode="keep_end", + + # Training parameters + output_dir="./output", + learning_rate=2e-5, + per_device_train_batch_size=4, + gradient_accumulation_steps=4, + max_steps=1000, + logging_steps=10, + save_steps=100, + optim="adamw_torch", + report_to="wandb", +) +``` + +### Model Initialization + +You can control how the model is initialized by passing keyword arguments to `model_init_kwargs`: + +```python +config = IterativeSFTConfig( + model_init_kwargs={ + "torch_dtype": "bfloat16", + "device_map": "auto", + "trust_remote_code": True, + } +) +``` + +### Data Preprocessing + +The trainer supports two truncation modes: + +- `keep_end`: Truncates from the start of the sequence +- `keep_start`: Truncates from the end of the sequence + +```python +config = IterativeSFTConfig( + max_length=512, + truncation_mode="keep_end", # or "keep_start" +) +``` + +### Training Optimization + +You can optimize CUDA cache usage for more memory-efficient training: + +```python +config = IterativeSFTConfig( + optimize_device_cache=True, +) +``` + +## IterativeSFTTrainer + +[[autodoc]] IterativeSFTTrainer + +## IterativeSFTConfig + +[[autodoc]] IterativeSFTConfig diff --git a/docs/source/judges.md b/docs/source/judges.md new file mode 100644 index 0000000000000000000000000000000000000000..d3fd1634161e76b34968008fb00998444296ebd9 --- /dev/null +++ b/docs/source/judges.md @@ -0,0 +1,89 @@ +# Judges + + + +TRL Judges is an experimental API which is subject to change at any time. + + + +TRL provides judges to easily compare two completions. + +Make sure to have installed the required dependencies by running: + +```bash +pip install trl[judges] +``` + +## Using the provided judges + +TRL provides several judges out of the box. For example, you can use the `HfPairwiseJudge` to compare two completions using a pre-trained model from the Hugging Face model hub: + +```python +from trl import HfPairwiseJudge + +judge = HfPairwiseJudge() +judge.judge( + prompts=["What is the capital of France?", "What is the biggest planet in the solar system?"], + completions=[["Paris", "Lyon"], ["Saturn", "Jupiter"]], +) # Outputs: [0, 1] +``` + +## Define your own judge + +To define your own judge, we provide several base classes that you can subclass. For rank-based judges, you need to subclass [`BaseRankJudge`] and implement the [`BaseRankJudge.judge`] method. For pairwise judges, you need to subclass [`BasePairJudge`] and implement the [`BasePairJudge.judge`] method. If you want to define a judge that doesn't fit into these categories, you need to subclass [`BaseJudge`] and implement the [`BaseJudge.judge`] method. + +As an example, let's define a pairwise judge that prefers shorter completions: + +```python +from trl import BasePairwiseJudge + +class PrefersShorterJudge(BasePairwiseJudge): + def judge(self, prompts, completions, shuffle_order=False): + return [0 if len(completion[0]) > len(completion[1]) else 1 for completion in completions] +``` + +You can then use this judge as follows: + +```python +judge = PrefersShorterJudge() +judge.judge( + prompts=["What is the capital of France?", "What is the biggest planet in the solar system?"], + completions=[["Paris", "The capital of France is Paris."], ["Jupiter is the biggest planet in the solar system.", "Jupiter"]], +) # Outputs: [0, 1] +``` + +## Provided judges + +### PairRMJudge + +[[autodoc]] PairRMJudge + +### HfPairwiseJudge + +[[autodoc]] HfPairwiseJudge + +### OpenAIPairwiseJudge + +[[autodoc]] OpenAIPairwiseJudge + +### AllTrueJudge + +[[autodoc]] AllTrueJudge + +## Base classes + +### BaseJudge + +[[autodoc]] BaseJudge + +### BaseBinaryJudge + +[[autodoc]] BaseBinaryJudge + +### BaseRankJudge + +[[autodoc]] BaseRankJudge + +### BasePairwiseJudge + +[[autodoc]] BasePairwiseJudge diff --git a/docs/source/kto_trainer.md b/docs/source/kto_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..92f377255b24a03707051294e9402807faccda89 --- /dev/null +++ b/docs/source/kto_trainer.md @@ -0,0 +1,139 @@ +# KTO Trainer + +[![](https://img.shields.io/badge/All_models-KTO-blue)](https://huggingface.co/models?other=kto,trl) + +## Overview + +Kahneman-Tversky Optimization (KTO) was introduced in [KTO: Model Alignment as Prospect Theoretic Optimization](https://huggingface.co/papers/2402.01306) by [Kawin Ethayarajh](https://huggingface.co/kawine), [Winnie Xu](https://huggingface.co/xwinxu), [Niklas Muennighoff](https://huggingface.co/Muennighoff), Dan Jurafsky, [Douwe Kiela](https://huggingface.co/douwekiela). + + +The abstract from the paper is the following: + +> Kahneman & Tversky's prospect theory tells us that humans perceive random variables in a biased but well-defined manner; for example, humans are famously loss-averse. We show that objectives for aligning LLMs with human feedback implicitly incorporate many of these biases -- the success of these objectives (e.g., DPO) over cross-entropy minimization can partly be ascribed to them being human-aware loss functions (HALOs). However, the utility functions these methods attribute to humans still differ from those in the prospect theory literature. Using a Kahneman-Tversky model of human utility, we propose a HALO that directly maximizes the utility of generations instead of maximizing the log-likelihood of preferences, as current methods do. We call this approach Kahneman-Tversky Optimization (KTO), and it matches or exceeds the performance of preference-based methods at scales from 1B to 30B. Crucially, KTO does not need preferences -- only a binary signal of whether an output is desirable or undesirable for a given input. This makes it far easier to use in the real world, where preference data is scarce and expensive. + +The official code can be found in [ContextualAI/HALOs](https://github.com/ContextualAI/HALOs). + +This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif), [Younes Belkada](https://huggingface.co/ybelkada), [Lewis Tunstall](https://huggingface.co/lewtun) and Pablo Vicente. + +## Quick start + +This example demonstrates how to train a model using the KTO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [KTO Mix 14k](https://huggingface.co/datasets/trl-lib/kto-mix-14k). You can view the data in the dataset here: + + + +Below is the script to train the model: + +```python +# train_kto.py +from datasets import load_dataset +from trl import KTOConfig, KTOTrainer +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +train_dataset = load_dataset("trl-lib/kto-mix-14k", split="train") + +training_args = KTOConfig(output_dir="Qwen2-0.5B-KTO", logging_steps=10) +trainer = KTOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset) +trainer.train() +``` + +Execute the script using the following command: + +```bash +accelerate launch train_kto.py +``` + +Distributed across 8 x H100 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time. + +![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/kto-qwen2-reward-margin.png) + +To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-KTO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models). + +
$ transformers chat trl-lib/Qwen2-0.5B-KTO
+<quentin_gallouedec>:
+What is the best programming language?
+
+<trl-lib/Qwen2-0.5B-KTO>:
+The best programming language can vary depending on individual preferences, industry-specific requirements, technical skills, and familiarity with the specific use case or task. Here are some widely-used programming languages that have been noted as popular and widely used:                                                                                  
+
+Here are some other factors to consider when choosing a programming language for a project:
+
+ 1 JavaScript: JavaScript is at the heart of the web and can be used for building web applications, APIs, and interactive front-end applications like frameworks like React and Angular. It's similar to C, C++, and F# in syntax structure and is accessible and easy to learn, making it a popular choice for beginners and professionals alike.                                                                   
+ 2 Java: Known for its object-oriented programming (OOP) and support for Java 8 and .NET, Java is used for developing enterprise-level software applications, high-performance games, as well as mobile apps, game development, and desktop applications.                                                                                                                                                            
+ 3 C++: Known for its flexibility and scalability, C++ offers comprehensive object-oriented programming and is a popular choice for high-performance computing and other technical fields. It's a powerful platform for building real-world applications and games at scale.                                                                                                                                         
+ 4 Python: Developed by Guido van Rossum in 1991, Python is a high-level, interpreted, and dynamically typed language known for its simplicity, readability, and versatility.   
+
+ +## Expected dataset format + +KTO requires an [unpaired preference dataset](dataset_formats#unpaired-preference). Alternatively, you can provide a *paired* preference dataset (also known simply as a *preference dataset*). In this case, the trainer will automatically convert it to an unpaired format by separating the chosen and rejected responses, assigning `label = True` to the chosen completions and `label = False` to the rejected ones. + +The [`KTOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. + +In theory, the dataset should contain at least one chosen and one rejected completion. However, some users have successfully run KTO using *only* chosen or only rejected data. If using only rejected data, it is advisable to adopt a conservative learning rate. + +## Example script + +We provide an example script to train a model using the KTO method. The script is available in [`trl/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/kto.py) + +To test the KTO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/kto-mix-14k), run the following command: + +```bash +accelerate launch trl/scripts/kto.py \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --dataset_name trl-lib/kto-mix-14k \ + --num_train_epochs 1 \ + --logging_steps 25 \ + --output_dir Qwen2-0.5B-KTO +``` + +## Usage tips + +### For Mixture of Experts Models: Enabling the auxiliary loss + +MOEs are the most efficient if the load is about equally distributed between experts. +To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss. + +This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]). +To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config. + + +### Batch size recommendations + +Use a per-step batch size that is at least 4, and an effective batch size between 16 and 128. Even if your effective batch size is large, if your per-step batch size is poor, then the KL estimate in KTO will be poor. + +### Learning rate recommendations + +Each choice of `beta` has a maximum learning rate it can tolerate before learning performance degrades. For the default setting of `beta = 0.1`, the learning rate should typically not exceed `1e-6` for most models. As `beta` decreases, the learning rate should also be reduced accordingly. In general, we strongly recommend keeping the learning rate between `5e-7` and `5e-6`. Even with small datasets, we advise against using a learning rate outside this range. Instead, opt for more epochs to achieve better results. + +### Imbalanced data + +The `desirable_weight` and `undesirable_weight` of the [`KTOConfig`] refer to the weights placed on the losses for desirable/positive and undesirable/negative examples. +By default, they are both 1. However, if you have more of one or the other, then you should upweight the less common type such that the ratio of (`desirable_weight` \\(\times\\) number of positives) to (`undesirable_weight` \\(\times\\) number of negatives) is in the range 1:1 to 4:3. + +## Logged metrics + +While training and evaluating we record the following reward metrics: + +- `rewards/chosen_sum`: the sum of log probabilities of the policy model for the chosen responses scaled by beta +- `rewards/rejected_sum`: the sum of log probabilities of the policy model for the rejected responses scaled by beta +- `logps/chosen_sum`: the sum of log probabilities of the chosen completions +- `logps/rejected_sum`: the sum of log probabilities of the rejected completions +- `logits/chosen_sum`: the sum of logits of the chosen completions +- `logits/rejected_sum`: the sum of logits of the rejected completions +- `count/chosen`: the count of chosen samples in a batch +- `count/rejected`: the count of rejected samples in a batch + +## KTOTrainer + +[[autodoc]] KTOTrainer + +## KTOConfig + +[[autodoc]] KTOConfig diff --git a/docs/source/liger_kernel_integration.md b/docs/source/liger_kernel_integration.md new file mode 100644 index 0000000000000000000000000000000000000000..6a2cdf0e7ee1950b81b849b223b72dcd39ca9ccb --- /dev/null +++ b/docs/source/liger_kernel_integration.md @@ -0,0 +1,7 @@ +# Liger Kernel Integration + + + +Section under construction. Feel free to contribute! + + \ No newline at end of file diff --git a/docs/source/logging.md b/docs/source/logging.md new file mode 100644 index 0000000000000000000000000000000000000000..b131436bcb6279098c23fe07dfd2a2d0a86276c0 --- /dev/null +++ b/docs/source/logging.md @@ -0,0 +1,99 @@ +# Logging + +As reinforcement learning algorithms are historically challenging to debug, it's important to pay careful attention to logging. +By default, TRL trainers like [`PPOTrainer`] and [`GRPOTrainer`] save a lot of relevant information to supported experiment trackers like Weights & Biases (wandb) or TensorBoard. + +Upon initialization, pass the `report_to` argument to the respective configuration object (e.g., [`PPOConfig`] for `PPOTrainer`, or [`GRPOConfig`] for `GRPOTrainer`): + +```python +# For PPOTrainer +ppo_config = PPOConfig( + # ..., + report_to="wandb" # or "tensorboard" +) + +# For GRPOTrainer +grpc_config = GRPOConfig( + # ..., + report_to="wandb" # or "tensorboard" +) +``` + +If you want to log with TensorBoard, you might also need to specify logging directories, for example, by adding `logging_dir=PATH_TO_LOGS` to the configuration object (e.g., `PPOConfig` or `GRPOConfig`). + +## PPO Logging + +Here's a brief explanation for the logged metrics provided in the data: + +* `eps`: Tracks the number of episodes per second. +* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy. +* `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy. +* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence. +* `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`. +* `objective/scores`: The mean scores returned by the reward model / environment. +* `policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`. +* `policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes. +* `loss/policy_avg`: The average policy loss, indicating how well the policy is performing. +* `loss/value_avg`: The average value loss, indicating the difference between the predicted value and the actual reward. +* `val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to `policy/clipfrac_avg` but for the value function. +* `policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are. +* `val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed. +* `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes. +* `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses. +* `lr`: The current learning rate used by the optimizer. +* `episode`: The current episode count in the training process. + +### Crucial values +During training, many values are logged, here are the most important ones: + +1. `objective/scores`: The mean scores returned by the reward model / environment. +1. `objective/rlhf_reward`: The mean RLHF reward. This is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up. +1. `objective/non_score_reward`: The mean reward from non-score-related sources (e.g., KL penalty). + +Here are some parameters that are useful to monitor for stability (when these diverge or collapse to 0, try tuning variables): + +1. `loss/value_avg`: The average value loss. It will spike / NaN when not going well. +1. `val/ratio`: The mean ratio of the current policy probability to the old policy probability. This number should float around 1.0. If this `ratio` is too high (e.g., 2.0 or 1000.0) or too small (e.g., 0.1), it means the updates between consecutive policies are too drastic. +1. `policy/clipfrac_avg` and `policy/approxkl_avg`: If `val/ratio` is too high, the `ratio` is going to get clipped, resulting in high `policy/clipfrac_avg` and high `policy/approxkl_avg` as well. +1. `objective/kl`: The mean KL divergence. It should stay positive and ideally not too large, so that the policy is not too far away from the reference policy. + +## GRPO Logging + +Here's a brief explanation for the logged metrics provided in the data for the GRPO trainer: + +* `num_tokens`: Total number of input tokens processed during training so far. + +**Completions:** +* `completions/mean_length`: Mean length of all generated completions (including those not ending with an EOS token). +* `completions/min_length`: Minimum length among all generated completions. +* `completions/max_length`: Maximum length among all generated completions. +* `completions/clipped_ratio`: The ratio of completions that did not end with an EOS token before reaching the maximum generation length (i.e., they were truncated). +* `completions/mean_terminated_length`: Mean length of only those completions that successfully ended with an EOS token. +* `completions/min_terminated_length`: Minimum length among completions that ended with an EOS token. +* `completions/max_terminated_length`: Maximum length among completions that ended with an EOS token. + +**Rewards:** +* `rewards/{reward_func_name}/mean`: The mean reward obtained from a specific, named reward function (e.g., `rewards/my_custom_reward/mean`). This is logged for each reward function used. +* `rewards/{reward_func_name}/std`: The standard deviation of rewards from a specific, named reward function. +* `reward`: The overall mean of the (potentially weighted and, if `args.scale_rewards` is true, normalized) rewards, after group-wise normalization (advantages). +* `reward_std`: The standard deviation of the (potentially weighted) rewards *before* group-wise normalization for advantages. + +**Policy and Loss Metrics:** +* `kl`: The mean Kullback-Leibler (KL) divergence between the current policy and the reference policy. This is logged only if `beta` (the KL coefficient in `GRPOConfig`) is non-zero. +* If Liger GRPOLoss is used (`use_liger_loss: True` in `GRPOConfig`): + * `clip_ratio`: The fraction of policy updates where the probability ratio was clipped according to the GRPO loss's epsilon bounds. +* If standard GRPOLoss is used (`use_liger_loss: False`): + * `clip_ratio/low_mean`: The mean fraction of instances where the probability ratio `r_t(θ)` was clipped at the lower bound `1 - epsilon_low` (occurs when advantage is negative and ratio is below the bound). + * `clip_ratio/low_min`: The minimum observed fraction for `clip_ratio/low_mean` across batches/processes. + * `clip_ratio/high_mean`: The mean fraction of instances where the probability ratio `r_t(θ)` was clipped at the upper bound `1 + epsilon_high` (occurs when advantage is positive and ratio is above the bound). + * `clip_ratio/high_max`: The maximum observed fraction for `clip_ratio/high_mean` across batches/processes. + * `clip_ratio/region_mean`: The mean fraction of instances where the probability ratio was clipped at either the lower or upper bound. + +### Crucial GRPO values +During GRPO training, monitor these values for insights into performance and stability: + +1. `reward`: This is the primary objective. It reflects the (group-wise normalized) rewards the policy is achieving. It should generally increase during successful training. +1. `kl`: If `beta > 0`, this tracks the divergence from the reference model. Keep an eye on it to ensure the policy doesn't stray too far, which can lead to instability. +1. `clip_ratio/*` (either `clip_ratio` for Liger loss or the more detailed `clip_ratio/...` metrics for standard loss): These indicate how often the policy updates are being constrained by the GRPO clipping mechanism. Very high values might suggest that the policy is trying to change too drastically (potentially due to large advantages or a learning rate that's too high) or that the epsilon clipping range is too restrictive. +1. `completions/clipped_ratio`: A high ratio here indicates that the model is frequently generating completions that are cut off by `max_completion_length` rather than naturally ending with an EOS token. This might suggest issues with learning sequence termination or that `max_completion_length` is too short. +1. `rewards/{reward_func_name}/mean`: Monitoring the mean of individual reward functions can help diagnose which aspects of the desired behavior the model is learning or struggling with, especially when using multiple reward sources. diff --git a/docs/source/model_utils.md b/docs/source/model_utils.md new file mode 100644 index 0000000000000000000000000000000000000000..64783ad97d8f4178ffdd871bc5c857f7b23785c8 --- /dev/null +++ b/docs/source/model_utils.md @@ -0,0 +1,5 @@ +# Model Utilities + +## get_act_offloading_ctx_manager + +[[autodoc]] models.get_act_offloading_ctx_manager diff --git a/docs/source/models.md b/docs/source/models.md new file mode 100644 index 0000000000000000000000000000000000000000..f96068fc46f160c6d60d3b95712fb277c826f6e9 --- /dev/null +++ b/docs/source/models.md @@ -0,0 +1,28 @@ +# Models + +With the `AutoModelForCausalLMWithValueHead` class TRL supports all decoder model architectures in transformers such as GPT-2, OPT, and GPT-Neo. In addition, with `AutoModelForSeq2SeqLMWithValueHead` you can use encoder-decoder architectures such as T5. TRL also requires reference models which are frozen copies of the model that is trained. With `create_reference_model` you can easily create a frozen copy and also share layers between the two models to save memory. + +## PreTrainedModelWrapper + +[[autodoc]] PreTrainedModelWrapper + +## AutoModelForCausalLMWithValueHead + + +[[autodoc]] AutoModelForCausalLMWithValueHead + - __init__ + - forward + - generate + - _init_weights + +## AutoModelForSeq2SeqLMWithValueHead + +[[autodoc]] AutoModelForSeq2SeqLMWithValueHead + - __init__ + - forward + - generate + - _init_weights + +## create_reference_model + +[[autodoc]] create_reference_model \ No newline at end of file diff --git a/docs/source/multi_adapter_rl.md b/docs/source/multi_adapter_rl.md new file mode 100644 index 0000000000000000000000000000000000000000..42cc9d4e9193a1ce6035083c3d25da9fd1b57194 --- /dev/null +++ b/docs/source/multi_adapter_rl.md @@ -0,0 +1,100 @@ +# Multi Adapter RL (MARL) - a single base model for everything + +Here we present an approach that uses a single base model for the entire PPO algorithm - which includes retrieving the reference logits, computing the active logits and the rewards. This feature is experimental as we did not test the convergence of the approach. We encourage the community to let us know if they potentially face issues. + +## Requirements + +You just need to install `peft` and optionally install `bitsandbytes` as well if you want to go for 8bit base models, for more memory efficient finetuning. + +## Summary + +You need to address this approach in three stages that we summarize as follows: + +1- Train a base model on the target domain (e.g. [IMDB dataset](https://huggingface.co/datasets/stanfordnlp/imdb)) - this is the Supervised Fine Tuning stage - it can leverage the `SFTTrainer` from TRL. +2- Train a reward model using `peft`. This is required in order to re-use the adapter during the RL optimisation process (step 3 below). We show an example of leveraging the `RewardTrainer` from TRL in [this example](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py) +3- Fine tune new adapters on the base model using PPO and the reward adapter. ("0 abstraction RL") + +Make sure to use the same model (i.e. same architecture and same weights) for the stages 2 & 3. + +## Quickstart + +Let us assume you have trained your reward adapter on `llama-7b` model using `RewardTrainer` and pushed the weights on the hub under `trl-lib/llama-7b-hh-rm-adapter`. +When doing PPO, before passing the model to `PPOTrainer` create your model as follows: + +```python +model_name = "huggyllama/llama-7b" +rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter" + +# PPO adapter +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +model = AutoModelForCausalLMWithValueHead.from_pretrained( + model_name, + peft_config=lora_config, + reward_adapter=rm_adapter_id, +) + +... +trainer = PPOTrainer( + model=model, + ... +) + +... +``` +Then inside your PPO training loop, call the `compute_reward_score` method by accessing the `model` attribute from `PPOTrainer`. + +```python +rewards = trainer.model.compute_reward_score(**inputs) +``` + +## Advanced usage + +### Control on the adapter name + +If you are familiar with the `peft` library, you know that you can use multiple adapters inside the same model. What you can do is train multiple adapters on the same base model to fine-tune on different policies. +In this case, you want to be able to control the adapter name you want to activate back, after retrieving the reward. For that, simply pass the appropriate `adapter_name` to `ppo_adapter_name` argument when calling `compute_reward_score`. + +```python +adapter_name_policy_1 = "policy_1" +rewards = trainer.model.compute_reward_score(**inputs, ppo_adapter_name=adapter_name_policy_1) +... +``` + +### Using 4-bit and 8-bit base models + +For more memory efficient fine-tuning, you can load your base model in 8-bit or 4-bit while keeping the adapters in the default precision (float32). +Just pass the appropriate arguments (i.e. `load_in_8bit=True` or `load_in_4bit=True`) to `AutoModelForCausalLMWithValueHead.from_pretrained` as follows (assuming you have installed `bitsandbytes`): +```python +model_name = "llama-7b" +rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter" + +# PPO adapter +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +model = AutoModelForCausalLMWithValueHead.from_pretrained( + model_name, + peft_config=lora_config, + reward_adapter=rm_adapter_id, + load_in_8bit=True, +) + +... +trainer = PPOTrainer( + model=model, + ... +) +... +``` diff --git a/docs/source/nash_md_trainer.md b/docs/source/nash_md_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..6d452c4ecbfa85b618d9f742df87f986208b3d9a --- /dev/null +++ b/docs/source/nash_md_trainer.md @@ -0,0 +1,159 @@ +# Nash-MD Trainer + +[![](https://img.shields.io/badge/All_models-Nash--MD-blue)](https://huggingface.co/models?other=nash-md,trl) + +## Overview + +Nash-MD was proposed in the paper [Nash Learning from Human Feedback](https://huggingface.co/papers/2312.00886) by Rémi Munos, [Michal Valko](https://huggingface.co/misovalko), Daniele Calandriello, Mohammad Gheshlaghi Azar, Mark Rowland, Daniel Guo, Yunhao Tang, Matthieu Geist, Thomas Mésnard, and Andrea Michi. + +The abstract from the paper is the following: + +> Reinforcement learning from human feedback (RLHF) has emerged as the main paradigm for aligning large language models (LLMs) with human preferences. Typically, RLHF involves the initial step of learning a reward model from human feedback, often expressed as preferences between pairs of text generations produced by a pre-trained LLM. Subsequently, the LLM's policy is fine-tuned by optimizing it to maximize the reward model through a reinforcement learning algorithm. However, an inherent limitation of current reward models is their inability to fully represent the richness of human preferences and their dependency on the sampling distribution. In this study, we introduce an alternative pipeline for the fine-tuning of LLMs using pairwise human feedback. Our approach entails the initial learning of a preference model, which is conditioned on two inputs given a prompt, followed by the pursuit of a policy that consistently generates responses preferred over those generated by any competing policy, thus defining the Nash equilibrium of this preference model. We term this approach Nash learning from human feedback (NLHF). In the context of a tabular policy representation, we present a novel algorithmic solution, Nash-MD, founded on the principles of mirror descent. This algorithm produces a sequence of policies, with the last iteration converging to the regularized Nash equilibrium. Additionally, we explore parametric representations of policies and introduce gradient descent algorithms for deep-learning architectures. To demonstrate the effectiveness of our approach, we present experimental results involving the fine-tuning of a LLM for a text summarization task. We believe NLHF offers a compelling avenue for preference learning and policy optimization with the potential of advancing the field of aligning LLMs with human preferences. + +This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif) and [Daniil Tiapkin](https://huggingface.co/dtiapkin), [Pierre Ménard](https://huggingface.co/menardprr), Daniele Calandriello and [Quentin Gallouédec](https://huggingface.co/qgallouedec). + +## Quick start + +This example demonstrates how to train a model using the Nash-MD method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and [`PairRMJudge`] as a judge. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here: + + + +Below is the script to train the model: + +```python +# train_nash_md.py +from datasets import load_dataset +from trl import NashMDConfig, NashMDTrainer, PairRMJudge +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +judge = PairRMJudge() +train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train") + +training_args = NashMDConfig(output_dir="Qwen2-0.5B-NashMD", logging_steps=10) +trainer = NashMDTrainer( + model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset +) +trainer.train() +``` + +Execute the script using the following command: + +```bash +accelerate launch train_nash_md.py +``` + +Distributed across 8 GPUs, the training takes approximately 3 hours. + +To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-NashMD) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models). + +
$ transformers chat trl-lib/Qwen2-0.5B-NashMD
+<quentin_gallouedec>:
+What is the best programming language?
+
+<trl-lib/Qwen2-0.5B-NashMD>:
+The best programming language depends on personal preference, the complexity of the project, and the specific requirements of the task. Some programming languages that are often recommended include Python, Java, and JavaScript, and there are many other languages to choose from depending on individual needs.
+
+ +## Expected dataset type + +Nash-MD requires a [prompt-only dataset](dataset_formats#prompt-only). The [`NashMDTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. + +## Usage tips + +### Use a reward model + +Instead of a judge, you can chose to use a reward model -- see [Reward Bench](https://huggingface.co/spaces/allenai/reward-bench) for a leaderboard of public models you can use. Below is a code example showing how to replace a judge with the [trl-lib/Qwen2-0.5B-Reward](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward) model: + +```diff +- from trl import PairRMJudge ++ from transformers import AutoModelForSequenceClassification + +- judge = PairRMJudge() ++ reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1) + + trainer = NashMDTrainer( + ... +- judge=judge, ++ reward_model=reward_model, + ) +``` + + + +Make sure that the SFT model and reward model use the _same_ chat template and the same tokenizer. Otherwise, you may find the model completions are scored incorrectly during training. + + + +### Encourage EOS token generation + +We may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`NashMDConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`NashMDConfig`]: + +```python +training_args = NashMDConfig(..., max_new_tokens=128, missing_eos_penalty=1.0) +``` + +### Logging Completions + +To better understand your model’s behavior during training, you can log sample completions periodically using the [`LogCompletionsCallback`]. + +```python +trainer = NashMDTrainer(..., eval_dataset=eval_dataset) +completions_callback = LogCompletionsCallback(trainer, num_prompts=8) +trainer.add_callback(completions_callback) +``` + +This callback logs the model's generated completions directly to Weights & Biases. + +![Logged Completions](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/wandb_completions.png) + +## Example script + +We provide an example script to train a model using the Nash-MD method. The script is available in [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) + +To test the online DPO script with the [Qwen2.5 0.5B model](https://huggingface.co/trl-lib/Qwen/Qwen2.5-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback), run the following command: + +```bash +python examples/scripts/nash_md.py \ + --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ + --judge pair_rm \ + --dataset_name trl-lib/ultrafeedback-prompt \ + --learning_rate 5.0e-7 \ + --logging_steps 25 \ + --output_dir Qwen2.5-0.5B-NashMD-PairRM \ + --warmup_ratio 0.1 \ + --push_to_hub +``` + +## Logged metrics + +The logged metrics are as follows: + +* `loss/kl`: The mean KL divergence between the model and reference data. +* `objective/entropy`: The mean entropy of the model and reference data. +* `loss/score`: The mean reinforce score loss. +* `rewards/chosen`: The mean scores (according to the reward model) of the model completions. +* `rewards/rejected`: The mean scores (according to the reward model) of the mixture completions. +* `rewards/probabilities`: The mean probability (according to the reward model or judge) of the model completions chosen vs the mixture completion. +* `rewards/accuracies`: The accuracies of the Nash-MD's implicit reward model. +* `rewards/margins`: The mean reward margin (according to reward model) between the chosen and mixture completions. +* `logps/chosen`: The mean log probabilities of the chosen completions. +* `logps/rejected`: The mean log probabilities of the reference completions. +* `val/model_contain_eos_token`: The amount of times the model's output contains the eos token. +* `val/ref_contain_eos_token`: The amount of times the mixture's output contains the eos token. +* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`NashMDConfig`]. +* `mixture_coef`: Logit mixture coefficient for the model and reference model. Typically fixed, but can be made dynamic by passing a list to [`NashMDConfig`]. + +## NashMDTrainer + +[[autodoc]] NashMDTrainer + +## NashMDConfig + +[[autodoc]] NashMDConfig diff --git a/docs/source/online_dpo_trainer.md b/docs/source/online_dpo_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..209b4c75835f18cc34690df05a8fa92c7b7bbac0 --- /dev/null +++ b/docs/source/online_dpo_trainer.md @@ -0,0 +1,278 @@ +# Online DPO Trainer + +[![](https://img.shields.io/badge/All_models-Online_DPO-blue)](https://huggingface.co/models?other=online-dpo,trl) + +## Overview + +Online DPO was proposed in [Direct Language Model Alignment from Online AI Feedback](https://huggingface.co/papers/2402.04792) by Shangmin Guo, Biao Zhang, Tianlin Liu, Tianqi Liu, Misha Khalman, Felipe Llinares, Alexandre Rame, Thomas Mesnard, Yao Zhao, Bilal Piot, Johan Ferret, and Mathieu Blondel. + +The abstract from the paper is the following: + +> Direct alignment from preferences (DAP) methods, such as DPO, have recently emerged as efficient alternatives to reinforcement learning from human feedback (RLHF), that do not require a separate reward model. However, the preference datasets used in DAP methods are usually collected ahead of training and never updated, thus the feedback is purely offline. Moreover, responses in these datasets are often sampled from a language model distinct from the one being aligned, and since the model evolves over training, the alignment phase is inevitably off-policy. In this study, we posit that online feedback is key and improves DAP methods. Our method, online AI feedback (OAIF), uses an LLM as annotator: on each training iteration, we sample two responses from the current model and prompt the LLM annotator to choose which one is preferred, thus providing online feedback. Despite its simplicity, we demonstrate via human evaluation in several tasks that OAIF outperforms both offline DAP and RLHF methods. We further show that the feedback leveraged in OAIF is easily controllable, via instruction prompts to the LLM annotator. + +This post-training method was contributed by [Michael Noukhovitch](https://huggingface.co/mnoukhov), [Shengyi Costa Huang](https://huggingface.co/vwxyzjn), [Quentin Gallouédec](https://huggingface.co/qgallouedec), and [Edward Beeching](https://huggingface.co/edbeeching). + +## Quick start + +This example demonstrates how to train a model using the online DPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and [`PairRMJudge`] as a judge. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here: + + + +Below is the script to train the model: + +```python +# train_online_dpo.py +from datasets import load_dataset +from trl import OnlineDPOConfig, OnlineDPOTrainer, PairRMJudge +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +judge = PairRMJudge() +train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train") + +training_args = OnlineDPOConfig(output_dir="Qwen2-0.5B-OnlineDPO", logging_steps=10) +trainer = OnlineDPOTrainer( + model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset +) +trainer.train() +``` + +Execute the script using the following command: + +```bash +accelerate launch train_online_dpo.py +``` + +Distributed across 8 GPUs, the training takes approximately 1 hour. You can verify the training progress by checking the reward graph. An increasing trend in both the reward for rejected and chosen completions indicates that the model is improving and generating better responses over time. + +![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/online-dpo-qwen2.png) + +To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-OnlineDPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models). + +
$ transformers chat trl-lib/Qwen2-0.5B-OnlineDPO
+<quentin_gallouedec>:
+What is the best programming language?
+
+<trl-lib/Qwen2-0.5B-OnlineDPO>:
+The best programming language depends on your specific needs and priorities. Some people prefer imperative programming languages (like Haskell or Lisp), while others prefer functional programming languages (like Scala or Python). It's important to consider your work style, programming environment, and project requirements when choosing a programming language.
+
+ +## Expected dataset type + +Online DPO only requires a [prompt-only dataset](dataset_formats#prompt-only) (unlike offline DPO, that expects [preference dataset](dataset_formats#preference)). The [`OnlineDPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. + +## Usage tips + +### Use a reward model + +Instead of a judge, you can chose to use a reward model -- see [Reward Bench](https://huggingface.co/spaces/allenai/reward-bench) for a leaderboard of public models you can use. Below is a code example showing how to replace a judge with the [trl-lib/Qwen2-0.5B-Reward](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward) model: + +```diff +- from trl import PairRMJudge ++ from transformers import AutoModelForSequenceClassification + +- judge = PairRMJudge() ++ reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1) ++ reward_tokenizer = AutoTokenizer.from_pretrained("trl-lib/Qwen2-0.5B-Reward") + + trainer = OnlineDPOTrainer( + ... +- judge=judge, ++ reward_model=reward_model, ++ reward_processing_class=reward_tokenizer, + ... + ) +``` + +### Encourage EOS token generation + +When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`OnlineDPOConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`OnlineDPOConfig`]: + +```python +training_args = OnlineDPOConfig(..., max_new_tokens=128, missing_eos_penalty=1.0) +``` + +### Logging Completions + +To better understand your model’s behavior during training, you can log sample completions periodically using the [`LogCompletionsCallback`]. + +```python +trainer = OnlineDPOTrainer(..., eval_dataset=eval_dataset) +completions_callback = LogCompletionsCallback(trainer, num_prompts=8) +trainer.add_callback(completions_callback) +``` + +This callback logs the model's generated completions directly to Weights & Biases. + +![Logged Completions](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/wandb_completions.png) + + +## Example script + +We provide an example script to train a model using the online DPO method. The script is available in [`examples/scripts/dpo_online.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_online.py) + +To test the online DPO script with the [Qwen2.5 0.5B model](https://huggingface.co/trl-lib/Qwen/Qwen2.5-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback), run the following command: + +```bash +python examples/scripts/dpo_online.py \ + --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ + --judge pair_rm \ + --dataset_name trl-lib/ultrafeedback-prompt \ + --learning_rate 5.0e-7 \ + --logging_steps 25 \ + --output_dir Qwen2.5-0.5B-Online-DPO-PairRM \ + --warmup_ratio 0.1 \ + --push_to_hub +``` + +## Logged metrics + +The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/w4apmsi9) + +* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current model and reference model. +* `objective/entropy`: The mean entropy of the model, indicating the randomness of the actions chosen by the model. +* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence. +* `objective/rlhf_reward`: The mean RLHF reward, which is `scores - non_score_reward`. The `rlhf_reward` is the ultimate objective of online DPO training. If training works as intended, this metric should keep going up. +* `objective/scores`: The mean scores returned by the reward model. +* `objective/scores_margin`: The mean score margin (according to the external reward model) between the chosen and rejected completions. +* `rewards/chosen`: The mean reward (according to online DPO's implicit reward model)of the chosen completions. +* `rewards/rejected`: The mean reward (according to online DPO's implicit reward model) of the rejected completions. +* `rewards/accuracies`: The accuracies of the online DPO's implicit reward model. +* `rewards/margins`: The mean reward margin (according to online DPO's implicit reward model) between the chosen and rejected completions. +* `logps/chosen`: The mean log probabilities of the chosen completions. +* `logps/rejected`: The mean log probabilities of the rejected completions. +* `val/contain_eos_token`: The fraction of completions which contain an EOS token. +* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`OnlineDPOConfig`]. + +## Benchmark experiments + +To validate the online DPO implementation works, we ran experiments with the Pythia 1B, 2.8B, and 6.9B models on a single node of 8 x H100s. Here are the commands we used to run the experiments. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031). + + +``` +# 1B Online DPO experiment +accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml \ + examples/scripts/dpo_online.py \ + --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \ + --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \ + --dataset_name trl-lib/tldr \ + --learning_rate 5.0e-7 \ + --output_dir pythia-1b-deduped-tldr-online-dpo \ + --beta 0.1 \ + --per_device_train_batch_size 8 \ + --gradient_accumulation_steps 2 \ + --num_train_epochs 3 \ + --max_new_tokens 53 \ + --warmup_ratio 0.1 \ + --missing_eos_penalty 1.0 \ + --logging_steps 20 \ + --save_steps 0.1 \ + --push_to_hub + +# 2.8B Online DPO experiment +accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \ + examples/scripts/dpo_online.py \ + --model_name_or_path trl-lib/pythia-2.8b-deduped-tldr-sft \ + --reward_model_path trl-lib/pythia-2.8b-deduped-tldr-rm \ + --dataset_name trl-lib/tldr \ + --learning_rate 5.0e-7 \ + --output_dir pythia-2.8b-deduped-tldr-online-dpo \ + --beta 0.1 \ + --per_device_train_batch_size 8 \ + --gradient_accumulation_steps 2 \ + --num_train_epochs 3 \ + --max_new_tokens 53 \ + --warmup_ratio 0.1 \ + --missing_eos_penalty 1.0 \ + --bf16 \ + --logging_steps 20 \ + --save_steps 0.1 \ + --push_to_hub + +# 6.9B Online DPO experiment +accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \ + examples/scripts/dpo_online.py \ + --model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-sft \ + --reward_model_path trl-lib/pythia-6.9b-deduped-tldr-rm \ + --dataset_name trl-lib/tldr \ + --learning_rate 5.0e-7 \ + --output_dir pythia-6.9b-deduped-tldr-online-dpo \ + --beta 0.1 \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 4 \ + --num_train_epochs 3 \ + --max_new_tokens 53 \ + --warmup_ratio 0.1 \ + --missing_eos_penalty 1.0 \ + --bf16 \ + --gradient_checkpointing \ + --logging_steps 20 \ + --save_steps 0.1 \ + --push_to_hub +``` + +Checkpoints and experiment tracking are available at: + +- [🤗 Model checkpoints](https://huggingface.co/collections/trl-lib/online-dpo-66acd3fa38a331a9cd457b07) +- [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/reports/Online-DPO-experiments-for-TL-DR-summarisation--Vmlldzo5MTczMDU0) + + +To evaluate, we use [vLLM](https://github.com/vllm-project/vllm) to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR. +For more information on how to use judges, see [Judges](judges). + +```bash +$ python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft --judge_model gpt-4o-mini --num_examples 1000 +Model win rate: 33.00% +python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-sft --judge_model gpt-4o-mini --num_examples 1000 +Model win rate: 41.50% +python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-1b-deduped-tldr-online-dpo --judge_model gpt-4o-mini --num_examples 1000 +Model win rate: 62.60% +python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-online-dpo --judge_model gpt-4o-mini --num_examples 1000 +Model win rate: 74.20% +``` + +We can then plot the RLHF scaling chart. + +```python +import matplotlib.pyplot as plt + +results = { + "SFT": {1.0e9: 0.21, 2.8e9: 0.27, 6.9e9: 0.316}, + "online-dpo": {1.0e9: 0.542, 2.8e9: 0.746, 6.9e9: 0.796}, + "offline-dpo": {1.0e9: 0.422, 2.8e9: 0.517, 6.9e9: 0.701}, +} + + +plt.plot(results["SFT"].keys(), results["SFT"].values(), label="SFT", marker="o") +plt.plot(results["online-dpo"].keys(), results["online-dpo"].values(), label="Online-dpo with RM judge", marker="o") +plt.plot(results["offline-dpo"].keys(), results["offline-dpo"].values(), label="Offline-dpo", marker="o") +plt.axhline(y=0.5, color="black", linestyle="-.", label="Human reference summary") +plt.xscale("log") +plt.xlabel("Model size") +plt.ylabel("Win rate against reference summaries\n(according to GPT-4-0613)") +plt.title("DPO scaling by model size") +plt.legend() +plt.xlim(5e8, 1.2e10) +plt.xticks([1e9, 3e9, 1e10], ["1B", "3B", "10B"]) +plt.grid(True, which="both", ls="--", c="0.7") +plt.tight_layout() +plt.show() +``` + +![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/online_dpo_scaling.png) + +The online DPO checkpoint gets increasingly more win rate as we scale up the model sizes. This is a good sign that the online DPO implementation is working as intended. + +## OnlineDPOTrainer + +[[autodoc]] OnlineDPOTrainer + +## OnlineDPOConfig + +[[autodoc]] OnlineDPOConfig diff --git a/docs/source/orpo_trainer.md b/docs/source/orpo_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..cf5fe8bce02cc5d5c5cc13a947f6d400932ae860 --- /dev/null +++ b/docs/source/orpo_trainer.md @@ -0,0 +1,129 @@ +# ORPO Trainer + +[![](https://img.shields.io/badge/All_models-ORPO-blue)](https://huggingface.co/models?other=orpo,trl) [![](https://img.shields.io/badge/smol_course-Chapter_2-yellow)](https://github.com/huggingface/smol-course/tree/main/2_preference_alignment) + +## Overview + +Odds Ratio Preference Optimization (ORPO) was introduced in [ORPO: Monolithic Preference Optimization without Reference Model](https://huggingface.co/papers/2403.07691) by [Jiwoo Hong](https://huggingface.co/JW17), [Noah Lee](https://huggingface.co/nlee-208), and [James Thorne](https://huggingface.co/j6mes). + +The abstract from the paper is the following: + +> While recent preference alignment algorithms for language models have demonstrated promising results, supervised fine-tuning (SFT) remains imperative for achieving successful convergence. In this paper, we study the crucial role of SFT within the context of preference alignment, emphasizing that a minor penalty for the disfavored generation style is sufficient for preference-aligned SFT. Building on this foundation, we introduce a straightforward and innovative reference model-free monolithic odds ratio preference optimization algorithm, ORPO, eliminating the necessity for an additional preference alignment phase. We demonstrate, both empirically and theoretically, that the odds ratio is a sensible choice for contrasting favored and disfavored styles during SFT across the diverse sizes from 125M to 7B. Specifically, fine-tuning Phi-2 (2.7B), Llama-2 (7B), and Mistral (7B) with ORPO on the UltraFeedback alone surpasses the performance of state-of-the-art language models with more than 7B and 13B parameters: achieving up to 12.20% on AlpacaEval_{2.0} (Figure 1), 66.19% on IFEval (instruction-level loose, Table 6), and 7.32 in MT-Bench (Figure 12). We release code and model checkpoints for Mistral-ORPO-alpha (7B) and Mistral-ORPO-beta (7B). + +It studies the crucial role of SFT within the context of preference alignment. Using preference data the method posits that a minor penalty for the disfavored generation together with a strong adaption signal to the chosen response via a simple log odds ratio term appended to the NLL loss is sufficient for preference-aligned SFT. + +Thus ORPO is a reference model-free preference optimization algorithm eliminating the necessity for an additional preference alignment phase thus saving compute and memory. + +The official code can be found in [xfactlab/orpo](https://github.com/xfactlab/orpo). + +This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif), [Lewis Tunstall](https://huggingface.co/lewtun) and [Alvaro Bartolome](https://huggingface.co/alvarobartt). + +## Quick start + +This example demonstrates how to train a model using the ORPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here: + + + +Below is the script to train the model: + +```python +# train_orpo.py +from datasets import load_dataset +from trl import ORPOConfig, ORPOTrainer +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train") + +training_args = ORPOConfig(output_dir="Qwen2-0.5B-ORPO", logging_steps=10) +trainer = ORPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset) +trainer.train() +``` + +Execute the script using the following command: + +```bash +accelerate launch train_orpo.py +``` + +Distributed across 8 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time. + +![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/orpo-qwen2-reward-margin.png) + +To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-ORPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models). + +
$ transformers chat trl-lib/Qwen2-0.5B-ORPO
+<quentin_gallouedec>:
+What is the best programming language?
+
+<trl-lib/Qwen2-0.5B-ORPO>:
+It's challenging to determine the best programming language as no one language is perfect, as the complexity of a task and the type of project are significant factors. Some popular languages include Java, Python, JavaScript, and
+C++. If you have specific needs or requirements for a specific project, it's important to choose the language that best suits those needs.                                                                                          
+
+Here are some other factors to consider when choosing a programming language for a project:
+
+ • Language proficiency: A good programming language is more likely to be easy to understand and use, and will allow developers to collaborate on projects more efficiently.                                     
+ • Ease of use: There are tools and libraries available to make programming more accessible, so developers should choose a language that can help them get started easier.
+ • Code readability: A clear and concise codebase should be easy to read and understand, especially when working with large projects.
+ • Tool and framework support: There are numerous libraries available for Python, Java, and JavaScript, along with tools like IDEs and static code analysis tools.
+ • Accessibility: Some languages and tools have features that make them more accessible to developers with disabilities, such as support for screen readers.
+ • Version control: As your projects grow and complexity increases, version control tools can be beneficial for tracking changes.
+
+
+ +## Expected dataset type + +ORPO requires a [preference dataset](dataset_formats#preference). The [`ORPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. + +Although the [`ORPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section. + +## Example script + +We provide an example script to train a model using the ORPO method. The script is available in [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) + +To test the ORPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command: + +```bash +accelerate launch examples/scripts/orpo.py \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --num_train_epochs 1 \ + --logging_steps 25 \ + --output_dir Qwen2-0.5B-ORPO +``` + +## Usage tips + +### For Mixture of Experts Models: Enabling the auxiliary loss + +MOEs are the most efficient if the load is about equally distributed between experts. +To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss. + +This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]). +To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config. + +## Logged metrics + +While training and evaluating we record the following reward metrics: + +- `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta +- `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta +- `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards +- `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards +- `log_odds_chosen`: the mean log odds ratio of the chosen responses over the rejected responses +- `log_odds_ratio`: the mean of the `log(sigmoid(log_odds_chosen))` +- `nll_loss`: the mean negative log likelihood loss from the SFT part of the loss over chosen responses + +## ORPOTrainer + +[[autodoc]] ORPOTrainer + +## ORPOConfig + +[[autodoc]] ORPOConfig diff --git a/docs/source/others.md b/docs/source/others.md new file mode 100644 index 0000000000000000000000000000000000000000..bd89447e7b877f5a24818099510de64ffa772aa0 --- /dev/null +++ b/docs/source/others.md @@ -0,0 +1,9 @@ +# Other + +## profiling_decorator + +[[autodoc]] extras.profiling.profiling_decorator + +## profiling_context + +[[autodoc]] extras.profiling.profiling_context diff --git a/docs/source/peft_integration.md b/docs/source/peft_integration.md new file mode 100644 index 0000000000000000000000000000000000000000..01bef1e9e0dd38ad40f8ea533b4384675e0cc5a4 --- /dev/null +++ b/docs/source/peft_integration.md @@ -0,0 +1,144 @@ +# Examples of using peft with trl to finetune 8-bit models with Low Rank Adaption (LoRA) + +The notebooks and scripts in this examples show how to use Low Rank Adaptation (LoRA) to fine-tune models in a memory efficient manner. Most of PEFT methods supported in peft library but note that some PEFT methods such as Prompt tuning are not supported. +For more information on LoRA, see the [original paper](https://huggingface.co/papers/2106.09685). + +Here's an overview of the `peft`-enabled notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples): + +| File | Task | Description | Colab link | +|---|---| --- | +| [`stack_llama/rl_training.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py) | RLHF | Distributed fine-tuning of the 7b parameter LLaMA models with a learned reward model and `peft`. | | +| [`stack_llama/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/reward_modeling.py) | Reward Modeling | Distributed training of the 7b parameter LLaMA reward model with `peft`. | | +| [`stack_llama/supervised_finetuning.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py) | SFT | Distributed instruction/supervised fine-tuning of the 7b parameter LLaMA model with `peft`. | | + +## Installation +Note: peft is in active development, so we install directly from their Github page. +Peft also relies on the latest version of transformers. + +```bash +pip install trl[peft] +pip install bitsandbytes loralib +pip install git+https://github.com/huggingface/transformers.git@main +#optional: wandb +pip install wandb +``` + +Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks. You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking). + +## How to use it? + +Simply declare a `PeftConfig` object in your script and pass it through `.from_pretrained` to load the TRL+PEFT model. + +```python +from peft import LoraConfig +from trl import AutoModelForCausalLMWithValueHead + +model_id = "edbeeching/gpt-neo-125M-imdb" +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +model = AutoModelForCausalLMWithValueHead.from_pretrained( + model_id, + peft_config=lora_config, +) +``` +And if you want to load your model in 8bit precision: +```python +pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained( + config.model_name, + load_in_8bit=True, + peft_config=lora_config, +) +``` +... or in 4bit precision: +```python +pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained( + config.model_name, + peft_config=lora_config, + load_in_4bit=True, +) +``` + + +## Launch scripts + +The `trl` library is powered by `accelerate`. As such it is best to configure and launch trainings with the following commands: + +```bash +accelerate config # will prompt you to define the training configuration +accelerate launch examples/scripts/ppo.py --use_peft # launch`es training +``` + +## Using `trl` + `peft` and Data Parallelism + +You can scale up to as many GPUs as you want, as long as you are able to fit the training process in a single device. The only tweak you need to apply is to load the model as follows: +```python +from peft import LoraConfig +... + +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained( + config.model_name, + peft_config=lora_config, +) +``` +And if you want to load your model in 8bit precision: +```python +pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained( + config.model_name, + peft_config=lora_config, + load_in_8bit=True, +) +``` +... or in 4bit precision: +```python +pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained( + config.model_name, + peft_config=lora_config, + load_in_4bit=True, +) +``` +Finally, make sure that the rewards are computed on correct device as well, for that you can use `ppo_trainer.model.current_device`. + +## Naive pipeline parallelism (NPP) for large models (>60B models) + +The `trl` library also supports naive pipeline parallelism (NPP) for large models (>60B models). This is a simple way to parallelize the model across multiple GPUs. +This paradigm, termed as "Naive Pipeline Parallelism" (NPP) is a simple way to parallelize the model across multiple GPUs. We load the model and the adapters across multiple GPUs and the activations and gradients will be naively communicated across the GPUs. This supports `int8` models as well as other `dtype` models. + +
+ +
+ +### How to use NPP? + +Simply load your model with a custom `device_map` argument on the `from_pretrained` to split your model across multiple devices. Check out this [nice tutorial](https://github.com/huggingface/blog/blob/main/accelerate-large-models.md) on how to properly create a `device_map` for your model. + +Also make sure to have the `lm_head` module on the first GPU device as it may throw an error if it is not on the first device. As this time of writing, you need to install the `main` branch of `accelerate`: `pip install git+https://github.com/huggingface/accelerate.git@main` and `peft`: `pip install git+https://github.com/huggingface/peft.git@main`. + +### Launch scripts + +Although `trl` library is powered by `accelerate`, you should run your training script in a single process. Note that we do not support Data Parallelism together with NPP yet. + +```bash +python PATH_TO_SCRIPT +``` + +## Fine-tuning Llama-2 model + +You can easily fine-tune Llama2 model using `SFTTrainer` and the official script! For example to fine-tune llama2-7b on the Guanaco dataset, run (tested on a single NVIDIA T4-16GB): + +```bash +python trl/scripts/sft.py --output_dir sft_openassistant-guanaco --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --per_device_train_batch_size 4 --gradient_accumulation_steps 2 +``` diff --git a/docs/source/ppo_trainer.md b/docs/source/ppo_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..359d482931633621fc55724ece4f67d22c8d9d01 --- /dev/null +++ b/docs/source/ppo_trainer.md @@ -0,0 +1,239 @@ +# PPO Trainer + +[![](https://img.shields.io/badge/All_models-PPO-blue)](https://huggingface.co/models?other=ppo,trl) + +TRL supports training LLMs with [Proximal Policy Optimization (PPO)](https://huggingface.co/papers/1707.06347). + +References: +- [Fine-Tuning Language Models from Human Preferences](https://github.com/openai/lm-human-preferences) +- [Learning to Summarize from Human Feedback](https://github.com/openai/summarize-from-feedback) +- [The N Implementation Details of RLHF with PPO](https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo) +- [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031) + +## Get started + +To just run a PPO script to make sure the trainer can run, you can run the following command to train a PPO model with a dummy reward model. + +```bash +python examples/scripts/ppo/ppo.py \ + --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ + --dataset_train_split descriptiveness \ + --learning_rate 3e-6 \ + --num_ppo_epochs 1 \ + --num_mini_batches 1 \ + --output_dir models/minimal/ppo \ + --per_device_train_batch_size 64 \ + --gradient_accumulation_steps 1 \ + --total_episodes 10000 \ + --model_name_or_path EleutherAI/pythia-1b-deduped \ + --sft_model_path EleutherAI/pythia-1b-deduped \ + --reward_model_path EleutherAI/pythia-1b-deduped \ + --missing_eos_penalty 1.0 +``` + + +## Explanation of the logged metrics + +The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35) + +* `eps`: Tracks the number of episodes per second. +* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy. +* `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy. +* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence. +* `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`. +* `objective/scores`: The mean scores returned by the reward model / environment. +* `policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`. +* `policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes. +* `loss/policy_avg`: The average policy loss, indicating how well the policy is performing. +* `loss/value_avg`: The average value loss, indicating the difference between the predicted value and the actual reward. +* `val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to policy/clipfrac_avg but for the value function. +* `policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are. +* `val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed. +* `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes. +* `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses. +* `lr`: lr: The current learning rate used by the optimizer. +* `episode`: episode: The current episode count in the training process. + + +## Cookbook + +* Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up. +* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high like 2.0 or 1000.0 or too small like 0.1, it means the updates between consecutive policies are too drastic. You should try understand why this is happening and try to fix it. +* Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint. +* Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`. +* Usage TIP: We recommend to use the "EOS trick" via `--missing_eos_penalty`, which subtracts a static scalar penalty from the score of completions that do not end with an EOS token. This can help the model learn to generate more coherent completions. + + +## What is my model doing exactly? + +To help you understand what your model is doing, we periodically log some sample completions from the model. Here is an example of a completion. In an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35), it looks like the following, allowing you to see the model's response at different stages of training. By default we generate `--num_sample_generations 10` during training, but you can customize the number of generations. + +![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/ppov2_completions.gif) + + +In the logs the sampled generations look like + +``` +┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓ +┃ query ┃ model response ┃ score ┃ +┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩ +│ SUBREDDIT: r/AskReddit │ I'm in love with a friend, and │ 3.921875 │ +│ │ I don't know how to get rid of │ │ +│ TITLE: How do you get someone │ those feelings. I'm │ │ +│ out of your head? │ desperate.<|endoftext|>[PAD][P… │ │ +│ │ │ │ +│ POST: Hi, │ │ │ +│ I'm 22, and I have been with my │ │ │ +│ girlfriend for 5 years now. We │ │ │ +│ recently moved together. We've │ │ │ +│ always loved each other │ │ │ +│ intensely. │ │ │ +│ │ │ │ +│ Problem, I recently started to │ │ │ +│ have feelings for an other │ │ │ +│ person (a friend). This person │ │ │ +│ has had a boyfriend for now 3 │ │ │ +│ years, and has absolutely no │ │ │ +│ ideas. Those feelings were so │ │ │ +│ strong, it was hard to hide │ │ │ +│ them. After 2 months of me │ │ │ +│ being distant and really sad, │ │ │ +│ my girlfriend forced me to say │ │ │ +│ what was bothering me. I'm not │ │ │ +│ a good liar, and now she knows. │ │ │ +│ │ │ │ +│ We decided to give us a week │ │ │ +│ alone, I went to my parents. │ │ │ +│ │ │ │ +│ Now, I'm completely lost. I │ │ │ +│ keep on thinking about this │ │ │ +│ person, and I hate that. I │ │ │ +│ would like for those feelings │ │ │ +│ to go away, to leave me alone. │ │ │ +│ But I can't. │ │ │ +│ │ │ │ +│ What do I do? It's been 3 │ │ │ +│ months now, and I'm just │ │ │ +│ desperate. │ │ │ +│ │ │ │ +│ TL;DR: │ │ │ +├─────────────────────────────────┼─────────────────────────────────┼──────────┤ +│ SUBREDDIT: r/pettyrevenge │ My mom woke me up with a loud │ 6.84375 │ +│ │ TV. I blasted Gangnam Style on │ │ +│ TITLE: So, my mom woke me up │ repeat, with the bass cranked │ │ +│ with a loud TV. │ up as high as it could │ │ +│ │ go.<|endoftext|>[PAD][PAD][PAD… │ │ +│ POST: She was in her living │ │ │ +│ room, watching TV. This was at │ │ │ +│ about 8:30 in the morning, and │ │ │ +│ she was exercising. She turned │ │ │ +│ the TV up extra loud to hear it │ │ │ +│ over her excercycle, and woke │ │ │ +│ me up. I went in there asking │ │ │ +│ for her to turn it down. She │ │ │ +│ said she didn't have to; I │ │ │ +│ explained that I always used │ │ │ +│ headphones so she didn't have │ │ │ +│ to deal with my noise and that │ │ │ +│ she should give me a little │ │ │ +│ more respect, given that I paid │ │ │ +│ rent at the time. │ │ │ +│ │ │ │ +│ She disagreed. I went back to │ │ │ +│ my room, rather pissed off at │ │ │ +│ the lack of equality. I had no │ │ │ +│ lock on my door; but I had a │ │ │ +│ dresser right next to it, so I │ │ │ +│ pulled one of the drawers out │ │ │ +│ enough so that it caused the │ │ │ +│ door to not be openable. Then, │ │ │ +│ I turned my speakers up really │ │ │ +│ loud and blasted Gangnam Style │ │ │ +│ on repeat, with the bass │ │ │ +│ cranked up as high as it could │ │ │ +│ go. │ │ │ +│ │ │ │ +│ If you hate Gangnam Style for │ │ │ +│ being overplayed, you will see │ │ │ +│ why I chose that particular │ │ │ +│ song. I personally don't mind │ │ │ +│ it. But here's the thing about │ │ │ +│ my bass; it vibrates the walls, │ │ │ +│ making one hell of a lot of │ │ │ +│ noise. Needless to say, my mom │ │ │ +│ was not pleased and shut off │ │ │ +│ the internet. But it was oh so │ │ │ +│ worth it. │ │ │ +│ │ │ │ +│ TL;DR: │ │ │ +└─────────────────────────────────┴─────────────────────────────────┴──────────┘ +``` + +## Implementation details + +This PPO implementation is based on the [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031). + +## Benchmark experiments + +To validate the PPO implementation works, we ran experiment on the 1B model. Here are the command we used to run the experiment. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031). + +``` +accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \ + examples/scripts/ppo/ppo_tldr.py \ + --output_dir models/minimal/ppo_tldr \ + --learning_rate 3e-6 \ + --per_device_train_batch_size 16 \ + --gradient_accumulation_steps 4 \ + --total_episodes 1000000 \ + --model_name_or_path EleutherAI/pythia-1b-deduped \ + --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ + --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ + --local_rollout_forward_batch_size 16 \ + --missing_eos_penalty 1.0 \ + --stop_token eos +``` + +Checkpoints and experiment tracking are available at: + +- [🤗 Model checkpoint](https://huggingface.co/vwxyzjn/ppo_tldr) +- [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/runs/dd2o3g35) + +To evaluate, we use [vLLM](https://github.com/vllm-project/vllm) to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR. +For more information on how to use judges, see [Judges](judges). + +```bash +$ python examples/scripts/evals/judge_tldr.py --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr --judge_model gpt-4o-mini --num_examples 1000 +Model win rate: 33.00% +$ python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/ppo_tldr --judge_model gpt-4o-mini --num_examples 1000 +Model win rate: 64.70% +``` + +The PPO checkpoint gets a 64.7% preferred rate vs the 33.0% preference rate of the SFT checkpoint. This is a good sign that the PPO training is working as intended. + +Metrics: + +![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/ppov2.png) + + +```bash +# pip install openrlbenchmark==0.2.1a5 +# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation +# to use it, change `?we=huggingface&wpn=trl` to your own project and `?tag=pr-1540` to your own tag +python -m openrlbenchmark.rlops_multi_metrics \ + --filters '?we=huggingface&wpn=trl&xaxis=train/episode&ceik=output_dir&cen=sft_model_path&metrics=train/objective/rlhf_reward&metrics=train/objective/scores&metrics=train/objective/kl&metrics=train/objective/non_score_reward&metrics=train/objective/entropy&metrics=train/policy/approxkl_avg&metrics=train/policy/clipfrac_avg&metrics=train/loss/policy_avg&metrics=train/loss/value_avg&metrics=train/val/clipfrac_avg&metrics=train/policy/entropy_avg&metrics=train/val/ratio&metrics=train/val/ratio_var&metrics=train/val/num_eos_tokens&metrics=train/lr&metrics=train/eps' \ + "cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr?tag=pr-1540" \ + --env-ids models/minimal/ppo_tldr \ + --pc.ncols 4 \ + --pc.ncols-legend 1 \ + --pc.xlabel "Episode" \ + --output-filename benchmark/trl/pr-1540/ppo \ + --scan-history +``` + +## PPOTrainer + +[[autodoc]] PPOTrainer + +## PPOConfig + +[[autodoc]] PPOConfig diff --git a/docs/source/prm_trainer.md b/docs/source/prm_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..51813ca8d2030b2761acd4c0b3e54dbfd90c3184 --- /dev/null +++ b/docs/source/prm_trainer.md @@ -0,0 +1,125 @@ +# PRM Trainer + +[![](https://img.shields.io/badge/All_models-PRM-blue)](https://huggingface.co/models?other=prm,trl) + + + +PRM Trainer is an experimental API which is subject to change at any time. + + + +## Overview + +Process-supervised Reward Models (PRM) were proposed in [Solving math word problems with process- and outcome-based feedback](https://huggingface.co/papers/2211.14275) by Jonathan Uesato, Nate Kushman, Ramana Kumar, Francis Song, Noah Siegel, Lisa Wang, Antonia Creswell, Geoffrey Irving, and Irina Higgins. + +The abstract from the paper is the following: + +> Recent work has shown that asking language models to generate reasoning steps improves performance on many reasoning tasks. When moving beyond prompting, this raises the question of how we should supervise such models: outcome-based approaches which supervise the final result, or process-based approaches which supervise the reasoning process itself? Differences between these approaches might naturally be expected not just in final-answer errors but also in reasoning errors, which can be difficult to detect and are problematic in many real-world domains such as education. We run the first comprehensive comparison between process- and outcome-based approaches trained on a natural language task, GSM8K. We find that pure outcome-based supervision produces similar final-answer error rates with less label supervision. However, for correct reasoning steps we find it necessary to use processbased supervision or supervision from learned reward models that emulate process-based feedback. In total, we improve the previous best results from 16.8% → 12.7% final-answer error and 14.0% → 3.4% reasoning error among final-answer-correct solutions. + +This post-training method was contributed by [Gaetan Lopez](https://github.com/gaetanlop), [Lewis Tunstall](https://huggingface.co/lewtun), [Quentin Gallouédec](https://huggingface.co/qgallouedec) and [Agustín Piqueres](https://huggingface.co/plaguss). + + +## Quick start + +This example demonstrates how to train a model using the PRM method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) as the base model. We use the stepwise supervision data from the [Math Shepherd dataset](https://huggingface.co/datasets/trl-lib/math_shepherd). You can view the data in the dataset here: + + + +Below is the script to train the model: + +```python +# train_prm.py +from datasets import load_dataset +from trl import PRMConfig, PRMTrainer +from transformers import AutoModelForTokenClassification, AutoTokenizer + +model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B", num_labels=2) +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B") +train_dataset = load_dataset("trl-lib/math_shepherd", split="train[:10%]") + +training_args = PRMConfig(output_dir="Qwen2-0.5B-Reward-Math-Sheperd", logging_steps=10) +trainer = PRMTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset) +trainer.train() +``` + +Execute the script using the following command: + +```bash +accelerate launch train_prm.py +``` + +Distributed across 8 GPUs, the training takes approximately 1 hour. + +To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward-Math-Sheperd) performs, you can use the following script. + + +```python +from datasets import load_dataset +from transformers import pipeline + +pipe = pipeline("token-classification", model="trl-lib/Qwen2-0.5B-Reward-Math-Sheperd") +dataset = load_dataset("trl-lib/math_shepherd") +example = { + "prompt": "Musa is the class teacher of a class of 45 students. He wants to split them into three groups by age. If a third of the class is under 11 years, and two-fifths are above 11 but under 13, how many students will be in the third group (13 years and above)?", + "completions": [ + "Step 1: A third of the class is under 11 years because 11 - 1/3 = <<11-1/3=7>>7.", + "Step 2: Two-fifths of the class are above 11 but under 13 because 2/5 * 11 = <<2/5*11=8>>8.", + "Step 3: There are 45 students, so the third group will have 45 - 7 - 8 = <<45-7-8=20>>20 students. The answer is: 20", + ], + "labels": [True, False, False], +} + + +separator = "\n" # It's important to use the same separator as the one used during training + +for idx in range(1, len(example["completions"]) + 1): + steps = example["completions"][0:idx] + text = separator.join((example["prompt"], *steps)) + separator # Add a separator between the prompt and each steps + pred_entity = pipe(text)[-1]["entity"] + pred = {"LABEL_0": False, "LABEL_1": True}[pred_entity] + label = example["labels"][idx - 1] + print(f"Step {idx}\tPredicted: {pred} \tLabel: {label}") +``` + +```text +Step 1 Predicted: True Label: True +Step 2 Predicted: False Label: False +Step 3 Predicted: False Label: False +``` + +It's a win! + +## Expected dataset type + +PRM requires a [stepwise supervision](dataset_formats#stepwise-supervision). +The dataset should contain the following columns: `prompt`, `completions` and `labels`, where `completions` contains a list of reasoning steps and `labels` a list of booleans or floats indicating the correctness of each step. + +The [`PRMTrainer`] only supports [standard](dataset_formats#standard) dataset format. + +## Example script + +We provide an example script to train a model using the PRM method. The script is available in [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py) + +To use the PRM script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) on the [Math Shepherd dataset](https://huggingface.co/datasets/trl-lib/math_shepherd), run the following command: + +```bash +accelerate launch examples/scripts/prm.py \ + --model_name_or_path Qwen/Qwen2-0.5B \ + --dataset_name trl-lib/math_shepherd \ + --num_train_epochs 1 \ + --logging_steps 25 \ + --output_dir Qwen2-0.5B-Reward-Math-Sheperd +``` + +## PRMTrainer + +[[autodoc]] PRMTrainer + +## PRMConfig + +[[autodoc]] PRMConfig diff --git a/docs/source/quickstart.md b/docs/source/quickstart.md new file mode 100644 index 0000000000000000000000000000000000000000..f310a101d8f2a4ecf3b561eaa41335b53aa893b4 --- /dev/null +++ b/docs/source/quickstart.md @@ -0,0 +1,88 @@ +# Quickstart + +## How does it work? + +Fine-tuning a language model via PPO consists of roughly three steps: + +1. **Rollout**: The language model generates a response or continuation based on a query which could be the start of a sentence. +2. **Evaluation**: The query and response are evaluated with a function, model, human feedback, or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair. The optimization will aim at maximizing this value. +3. **Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO. + +The full process is illustrated in the following figure: + + +## Minimal example + +The following code illustrates the steps above. + +```python +# 0. imports +import torch +from transformers import GPT2Tokenizer + +from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer + + +# 1. load a pretrained model +model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2") +ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2") +tokenizer = GPT2Tokenizer.from_pretrained("gpt2") +tokenizer.pad_token = tokenizer.eos_token + +# 2. initialize trainer +ppo_config = {"mini_batch_size": 1, "batch_size": 1} +config = PPOConfig(**ppo_config) +ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer) + +# 3. encode a query +query_txt = "This morning I went to the " +query_tensor = tokenizer.encode(query_txt, return_tensors="pt").to(model.pretrained_model.device) + +# 4. generate model response +generation_kwargs = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.eos_token_id, + "max_new_tokens": 20, +} +response_tensor = ppo_trainer.generate([item for item in query_tensor], return_prompt=False, **generation_kwargs) +response_txt = tokenizer.decode(response_tensor[0]) + +# 5. define a reward for response +# (this could be any reward such as human feedback or output from another model) +reward = [torch.tensor(1.0, device=model.pretrained_model.device)] + +# 6. train model with ppo +train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward) +``` + +In general, you would run steps 3-6 in a for-loop and run it on many diverse queries. You can find more realistic examples in the examples section. + +## How to use a trained model + +After training a `AutoModelForCausalLMWithValueHead`, you can directly use the model in `transformers`. +```python + +# .. Let's assume we have a trained model using `PPOTrainer` and `AutoModelForCausalLMWithValueHead` + +# push the model on the Hub +model.push_to_hub("my-fine-tuned-model-ppo") + +# or save it locally +model.save_pretrained("my-fine-tuned-model-ppo") + +# load the model from the Hub +from transformers import AutoModelForCausalLM + +model = AutoModelForCausalLM.from_pretrained("my-fine-tuned-model-ppo") +``` + +You can also load your model with `AutoModelForCausalLMWithValueHead` if you want to use the value head, for example to continue training. + +```python +from trl.model import AutoModelForCausalLMWithValueHead + +model = AutoModelForCausalLMWithValueHead.from_pretrained("my-fine-tuned-model-ppo") +``` diff --git a/docs/source/reducing_memory_usage.md b/docs/source/reducing_memory_usage.md new file mode 100644 index 0000000000000000000000000000000000000000..ded2c13b58a392196e7fff37ab15197f67b9d0de --- /dev/null +++ b/docs/source/reducing_memory_usage.md @@ -0,0 +1,261 @@ +# Reducing Memory Usage + + + +Section under construction. Feel free to contribute! + + + +## Truncation + +Sequence lengths in the dataset can vary widely. When data is batched, sequences are padded to match the longest one in the batch, which can cause high memory usage, even if most sequences are relatively short. + +
+ Truncation prompt completion +
+ +To reduce memory usage, it's important to truncate sequences to a reasonable length. While TRL trainers truncate sequences by default, you may want to adjust the default truncation length to better align with your specific use case. + + + + +DPO truncation is applied first to the prompt and to the completion via the `max_prompt_length` and `max_completion_length` parameters. The `max_length` parameter is then used to truncate the resulting sequence. + +
+ Truncation prompt completion +
+ +To set the truncation parameters, use the following code snippet: + +```python +from trl import DPOConfig + +training_args = DPOConfig(..., max_prompt_length=..., max_length=...) +``` + +You can also use the `max_completion_length` parameter to truncate the completion, though this is less common since the goal is typically to preserve the completion's full length whenever possible. + +```python +from trl import DPOConfig + +training_args = DPOConfig(..., max_completion_length=...) +``` + +
+ + +SFT truncation is applied to the input sequence via the `max_length` parameter. + +
+ Truncation input ids +
+ +To set the truncation parameter, use the following code snippet: + +```python +from trl import SFTConfig + +training_args = SFTConfig(..., max_length=...) +``` + +
+
+ +## Packing + + + +This technique applies only to SFT. + + + + +[Truncation](#truncation) has several drawbacks: +1. **Loss of information**: Key data at the end of a sequence may be discarded. +2. **Choosing truncation length**: Too short loses data; too long undermines efficiency. + +Packing, introduced in [Raffel et al., 2020](https://huggingface.co/papers/1910.10683), addresses these issues by grouping sequences instead of truncating. It concatenates and splits dataset sequences into the desired lengths. + +
+ Packing +
+ +Packing reduces padding by merging several sequences in one row when possible. We use an advanced method to be near-optimal in the way we pack the dataset. To enable packing, use `packing=True` and in the [`SFTConfig`]. + + + +In TRL 0.18 and earlier, packing used a more aggressive method that reduced padding to almost nothing, but had the downside of breaking sequence continuity for a large fraction of the dataset. To revert to this strategy, use `packing_strategy="wrapped"` in `SFTConfig`. + + + +```python +from trl import SFTConfig + +training_args = SFTConfig(..., packing=True, max_length=512) +``` + + + +Packing may cause batch contamination, where adjacent sequences influence one another. This can be problematic for some applications. For more details, see [#1230](https://github.com/huggingface/trl/issues/1230). + + + +## Liger for reducing peak memory usage + +> [Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU training throughput by 20% and reduces memory usage by 60%. + +For more information, see [Liger Kernel Integration](liger_kernel_integration) + + + + +To use Liger for reducing peak memory usage, use the following code snippet: + +```python +from trl import DPOConfig + +training_args = DPOConfig(..., use_liger_loss=True) +``` + + + + +To use Liger for reducing peak memory usage, use the following code snippet: + +```python +from trl import GRPOConfig + +training_args = GRPOConfig(..., use_liger_loss=True) +``` + + + + +To use Liger for reducing peak memory usage, use the following code snippet: + +```python +from trl import KTOConfig + +training_args = KTOConfig(..., use_liger_loss=True) +``` + + + + +## Padding-free + +Padding-free batching is an alternative approach for reducing memory usage. In this method, a batch is first sampled and then flattened into a single sequence, avoiding padding. Unlike packing, which can result in incomplete sequences by combining parts of different samples, padding-free batching ensures that all sequences remain complete and intact. + +
+ Padding-free batching +
+ + + +It's highly recommended to use padding-free batching with **Flash Attention 2**. Otherwise, you may encounter batch contamination issues. + + + + + + +```python +from trl import DPOConfig + +training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "flash_attention_2"}) +``` + + + + +```python +from trl import SFTConfig + +training_args = SFTConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "flash_attention_2"}) +``` + + + + +## Activation offloading + +Activation offloading is a memory efficiency technique that reduces GPU VRAM usage by temporarily moving activation tensors to CPU RAM during the forward pass and bringing them back only when needed for the backward pass. This significantly reduces peak memory usage at the cost of slightly increased training time. + +To enable activation offloading in your SFT training configuration: + + + + +```python +from trl import SFTConfig + +training_args = SFTConfig(..., activation_offloading=True) +``` + + + + + + +When using activation offloading with models that use Liger kernels, you must disable Liger cross entropy due to compatibility issues. The issue occurs specifically with `use_liger_kernel=True` because Liger cross entropy performs in-place operations which conflict with activation offloading. The default setting (`use_liger_kernel=False`) works: + +```python +# When using activation offloading with a model that uses Liger kernels: +from trl import SFTConfig + +training_args = SFTConfig( + activation_offloading=True, + use_liger_kernel=False, # Disable Liger cross entropy + # Other parameters... +) +``` + + +Under the hood, activation offloading implements PyTorch's [`saved_tensors_hooks`](https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html#hooks-for-autograd-saved-tensors) to intercept activations during the forward pass. It intelligently manages which tensors to offload based on size and context, avoiding offloading output tensors which would be inefficient. For performance optimization, it can optionally use CUDA streams to overlap computation with CPU-GPU transfers. + +## Disabling model gathering for generation in online methods + +When using DeepSpeed ZeRO-3, model weights are sharded across multiple GPUs. Online methods involve generating completions from the model as part of the training process. During this step, the model weights are temporarily gathered on a single GPU for generation. For very large models, this gathering can lead to out-of-memory (OOM) errors, as described in this issue: [#2250](https://github.com/huggingface/trl/issues/2250#issue-2598304204). + +If you encounter this issue, you can disable the gathering of model weights for generation by setting the following parameter: + + + + +```python +from trl import GRPOConfig + +training_args = GRPOConfig(..., ds3_gather_for_generation=False) +``` + + + + +```python +from trl import OnlineDPOConfig + +training_args = OnlineDPOConfig(..., ds3_gather_for_generation=False) +``` + + + + +```python +from trl import PPOConfig + +training_args = PPOConfig(..., ds3_gather_for_generation=False) +``` + + + + +```python +from trl import RLOOConfig + +training_args = RLOOConfig(..., ds3_gather_for_generation=False) +``` + + + + +This adjustment prevents model weights from being gathered, avoiding OOM errors, but it may result in slower generation speeds. diff --git a/docs/source/reward_trainer.md b/docs/source/reward_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..09c2ac863c4f222620f6aa37ff45f12e278ca394 --- /dev/null +++ b/docs/source/reward_trainer.md @@ -0,0 +1,90 @@ +# Reward Modeling + +[![](https://img.shields.io/badge/All_models-Reward_Trainer-blue)](https://huggingface.co/models?other=reward-trainer,trl) + +TRL supports custom reward modeling for anyone to perform reward modeling on their dataset and model. + +Check out a complete flexible example at [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py). + +## Expected dataset type + +The [`RewardTrainer`] requires a [*implicit prompt* preference dataset](dataset_formats#preference). It means that the dataset should only contain the columns `"chosen"` and `"rejected"` (and not `"prompt"`). +The [`RewardTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. + +You can also use a pretokenized dataset, in which case the dataset should contain the following columns: `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`. + +## Using the `RewardTrainer` + +After preparing your dataset, you can use the [`RewardTrainer`] in the same way as the `Trainer` class from 🤗 Transformers. +You should pass an `AutoModelForSequenceClassification` model to the [`RewardTrainer`], along with a [`RewardConfig`] which configures the hyperparameters of the training. + +### Leveraging 🤗 PEFT to train a reward model + +Just pass a `peft_config` in the keyword arguments of [`RewardTrainer`], and the trainer should automatically take care of converting the model into a PEFT model! + +```python +from peft import LoraConfig, TaskType +from transformers import AutoModelForSequenceClassification, AutoTokenizer +from trl import RewardTrainer, RewardConfig + +model = AutoModelForSequenceClassification.from_pretrained("gpt2") +peft_config = LoraConfig( + task_type=TaskType.SEQ_CLS, + inference_mode=False, + r=8, + lora_alpha=32, + lora_dropout=0.1, +) + +... + +trainer = RewardTrainer( + model=model, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset, + peft_config=peft_config, +) + +trainer.train() + +``` + +### Adding a margin to the loss + +As in the [Llama 2 paper](https://huggingface.co/papers/2307.09288), you can add a margin to the loss by adding a `margin` column to the dataset. The reward collator will automatically pass it through and the loss will be computed accordingly. + +```python +def add_margin(row): + # Assume you have a score_chosen and score_rejected columns that you want to use to compute the margin + return {'margin': row['score_chosen'] - row['score_rejected']} + +dataset = dataset.map(add_margin) +``` + +### Centering rewards + +In many scenarios, it's preferable to ensure that a reward model's output is mean zero. This is often done by first calculating the model's average score and then subtracting it. + +[[Eisenstein et al., 2023]](https://huggingface.co/papers/2312.09244) proposed an auxiliary loss function designed to directly learn a centered reward model. This auxiliary loss minimizes the squared sum of the rewards, encouraging the model to naturally produce mean-zero outputs: + +$$\Big( R(p, r_1) + R(p, r_2) \Big)^2 $$ + +This auxiliary loss is combined with the main loss function, weighted by the parameter `center_rewards_coefficient` in the `[RewardConfig]`. By default, this feature is deactivated (`center_rewards_coefficient = None`). + +```python +training_args = RewardConfig( + center_rewards_coefficient=0.01, + ... +) +``` + +For reference results, please refer PR [#1932](https://github.com/huggingface/trl/pull/1932). + +## RewardTrainer + +[[autodoc]] RewardTrainer + +## RewardConfig + +[[autodoc]] RewardConfig diff --git a/docs/source/rewards.md b/docs/source/rewards.md new file mode 100644 index 0000000000000000000000000000000000000000..30c56d25e7f781f90e29c76ff99c46af6164bb91 --- /dev/null +++ b/docs/source/rewards.md @@ -0,0 +1,9 @@ +# Reward Functions + +This module contains some useful reward functions, primarily intended for use with the [`GRPOTrainer`]. + +## Format rewards + +### think_format_reward + +[[autodoc]] rewards.think_format_reward diff --git a/docs/source/rloo_trainer.md b/docs/source/rloo_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..5ad5eca53c94a067b9a718424d11422d144e4463 --- /dev/null +++ b/docs/source/rloo_trainer.md @@ -0,0 +1,290 @@ +# RLOO Trainer + +[![](https://img.shields.io/badge/All_models-RLOO-blue)](https://huggingface.co/models?other=rloo,trl) + +TRL supports training LLMs with REINFORCE Leave-One-Out (RLOO). The idea is that instead of using a value function, RLOO generates K completions for each prompt. For each completion, RLOO uses the mean scores from the other K-1 completions as a baseline to calculate the advantage. RLOO also models the entire completion as a single action, whereas PPO models each token as an action. Note that REINFORCE / A2C is a special case of PPO, when the number of PPO epochs is 1 and the number of mini-batches is 1, which is how we implement RLOO in TRL. + +References: +- [Back to Basics: Revisiting REINFORCE Style Optimization for Learning from Human Feedback in LLMs](https://huggingface.co/papers/2402.14740) +- [A2C is a special case of PPO](https://huggingface.co/papers/2205.09123) +- [Fine-Tuning Language Models from Human Preferences](https://github.com/openai/lm-human-preferences) +- [Learning to Summarize from Human Feedback](https://github.com/openai/summarize-from-feedback) +- [The N Implementation Details of RLHF with PPO](https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo) +- [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031) + +## Get started + +To just run a RLOO script to make sure the trainer can run, you can run the following command to train a RLOO model with a dummy reward model. + +```bash +python examples/scripts/rloo/rloo.py \ + --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ + --dataset_train_split descriptiveness \ + --learning_rate 3e-6 \ + --output_dir models/minimal/rloo \ + --per_device_train_batch_size 64 \ + --gradient_accumulation_steps 1 \ + --total_episodes 10000 \ + --model_name_or_path EleutherAI/pythia-14m \ + --reward_model_path EleutherAI/pythia-14m \ + --missing_eos_penalty 1.0 +``` + + +## Explanation of the logged metrics + +The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/u2sqci34) + + + +* `eps`: Tracks the number of episodes per second. +* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy. +* `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy. +* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence. +* `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`. +* `objective/scores`: The mean scores returned by the reward model / environment. +* `policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`. +* `policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes. +* `loss/policy_avg`: The average policy loss, indicating how well the policy is performing. +* `val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to policy/clipfrac_avg but for the value function. +* `policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are. +* `val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed. +* `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes. +* `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses. +* `lr`: lr: The current learning rate used by the optimizer. +* `episode`: episode: The current global step or episode count in the training process. + + +## Cookbook + +* Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up. +* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high like 2.0 or 1000.0 or too small like 0.1, it means the updates between consecutive policies are too drastic. You should try understand why this is happening and try to fix it. +* Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint. +* Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`. +* Usage TIP: We recommend to use the "EOS trick" via `--missing_eos_penalty`, which subtracts a static scalar penalty from the score of completions that do not end with an EOS token. This can help the model learn to generate more coherent completions. + + +## What is my model doing exactly? + +To help you understand what your model is doing, we periodically log some sample completions from the model. Here is an example of a completion. In an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/u2sqci34), it looks like the following, allowing you to see the model's response at different stages of training. By default we generate `--num_sample_generations 10` during training, but you can customize the number of generations. + +![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/ppov2_completions.gif) + + +In the logs the sampled generations look like + +``` +┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓ +┃ query ┃ model response ┃ score ┃ +┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩ +│ SUBREDDIT: r/AskReddit │ I'm in love with a friend, and │ 3.921875 │ +│ │ I don't know how to get rid of │ │ +│ TITLE: How do you get someone │ those feelings. I'm │ │ +│ out of your head? │ desperate.<|endoftext|>[PAD][P… │ │ +│ │ │ │ +│ POST: Hi, │ │ │ +│ I'm 22, and I have been with my │ │ │ +│ girlfriend for 5 years now. We │ │ │ +│ recently moved together. We've │ │ │ +│ always loved each other │ │ │ +│ intensely. │ │ │ +│ │ │ │ +│ Problem, I recently started to │ │ │ +│ have feelings for an other │ │ │ +│ person (a friend). This person │ │ │ +│ has had a boyfriend for now 3 │ │ │ +│ years, and has absolutely no │ │ │ +│ ideas. Those feelings were so │ │ │ +│ strong, it was hard to hide │ │ │ +│ them. After 2 months of me │ │ │ +│ being distant and really sad, │ │ │ +│ my girlfriend forced me to say │ │ │ +│ what was bothering me. I'm not │ │ │ +│ a good liar, and now she knows. │ │ │ +│ │ │ │ +│ We decided to give us a week │ │ │ +│ alone, I went to my parents. │ │ │ +│ │ │ │ +│ Now, I'm completely lost. I │ │ │ +│ keep on thinking about this │ │ │ +│ person, and I hate that. I │ │ │ +│ would like for those feelings │ │ │ +│ to go away, to leave me alone. │ │ │ +│ But I can't. │ │ │ +│ │ │ │ +│ What do I do? It's been 3 │ │ │ +│ months now, and I'm just │ │ │ +│ desperate. │ │ │ +│ │ │ │ +│ TL;DR: │ │ │ +├─────────────────────────────────┼─────────────────────────────────┼──────────┤ +│ SUBREDDIT: r/pettyrevenge │ My mom woke me up with a loud │ 6.84375 │ +│ │ TV. I blasted Gangnam Style on │ │ +│ TITLE: So, my mom woke me up │ repeat, with the bass cranked │ │ +│ with a loud TV. │ up as high as it could │ │ +│ │ go.<|endoftext|>[PAD][PAD][PAD… │ │ +│ POST: She was in her living │ │ │ +│ room, watching TV. This was at │ │ │ +│ about 8:30 in the morning, and │ │ │ +│ she was exercising. She turned │ │ │ +│ the TV up extra loud to hear it │ │ │ +│ over her excercycle, and woke │ │ │ +│ me up. I went in there asking │ │ │ +│ for her to turn it down. She │ │ │ +│ said she didn't have to; I │ │ │ +│ explained that I always used │ │ │ +│ headphones so she didn't have │ │ │ +│ to deal with my noise and that │ │ │ +│ she should give me a little │ │ │ +│ more respect, given that I paid │ │ │ +│ rent at the time. │ │ │ +│ │ │ │ +│ She disagreed. I went back to │ │ │ +│ my room, rather pissed off at │ │ │ +│ the lack of equality. I had no │ │ │ +│ lock on my door; but I had a │ │ │ +│ dresser right next to it, so I │ │ │ +│ pulled one of the drawers out │ │ │ +│ enough so that it caused the │ │ │ +│ door to not be openable. Then, │ │ │ +│ I turned my speakers up really │ │ │ +│ loud and blasted Gangnam Style │ │ │ +│ on repeat, with the bass │ │ │ +│ cranked up as high as it could │ │ │ +│ go. │ │ │ +│ │ │ │ +│ If you hate Gangnam Style for │ │ │ +│ being overplayed, you will see │ │ │ +│ why I chose that particular │ │ │ +│ song. I personally don't mind │ │ │ +│ it. But here's the thing about │ │ │ +│ my bass; it vibrates the walls, │ │ │ +│ making one hell of a lot of │ │ │ +│ noise. Needless to say, my mom │ │ │ +│ was not pleased and shut off │ │ │ +│ the internet. But it was oh so │ │ │ +│ worth it. │ │ │ +│ │ │ │ +│ TL;DR: │ │ │ +└─────────────────────────────────┴─────────────────────────────────┴──────────┘ +``` + +## Implementation details + +The bulk of RLOOTrainer is based on the PPO implementation, which is based on the [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031). + + +Below is a vectorized advantage calculation for RLOO: + +```python +def test_rloo_reward(): + local_batch_size = 3 + rloo_k = 4 + rlhf_reward = torch.tensor([ + 1, 2, 3, # first rlhf reward for three prompts + 2, 3, 4, # second rlhf reward for three prompts + 5, 6, 7, # third rlhf reward for three prompts + 8, 9, 10, # fourth rlhf reward for three prompts + ]).float() # here we have 3 prompts which have 4 completions each + + baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1) + advantages = torch.zeros_like(rlhf_reward) + for i in range(0, len(advantages), local_batch_size): + other_response_rlhf_rewards = [] + for j in range(0, len(advantages), local_batch_size): + if i != j: + other_response_rlhf_rewards.append(rlhf_reward[j : j + local_batch_size]) + advantages[i : i + local_batch_size] = rlhf_reward[i : i + local_batch_size] - torch.stack(other_response_rlhf_rewards).mean(0) + + assert (1 - (2 + 5 + 8) / 3 - advantages[0].item()) < 1e-6 # First rlhf reward for the first prompt + assert (6 - (3 + 2 + 9) / 3 - advantages[7].item()) < 1e-6 # Third rlhf reward for the second prompt + + # Vectorized implementation + rlhf_reward = rlhf_reward.reshape(rloo_k, local_batch_size) + baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1) + vec_advantages = rlhf_reward - baseline + torch.testing.assert_close(vec_advantages.flatten(), advantages) +``` + +## Benchmark experiments + +To validate the RLOO implementation works, we ran experiment on the 1B model. Here are the command we used to run the experiment. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031). + +``` +accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \ + --output_dir models/minimal/rloo_tldr \ + --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ + --dataset_test_split validation \ + --num_ppo_epochs 2 \ + --num_mini_batches 2 \ + --learning_rate 3e-6 \ + --per_device_train_batch_size 16 \ + --gradient_accumulation_steps 16 \ + --total_episodes 1000000 \ + --model_name_or_path EleutherAI/pythia-1b-deduped \ + --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ + --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ + --local_rollout_forward_batch_size 16 \ + --missing_eos_penalty 1.0 \ + --stop_token eos \ + --kl_coef 0.03 +``` + +Checkpoints and experiment tracking are available at: + +- [🤗 Model checkpoint](https://huggingface.co/vwxyzjn/rloo_tldr) +- [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/runs/u2sqci34) + + +To evaluate, we use [vLLM](https://github.com/vllm-project/vllm) to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR. +For more information on how to use judges, see [Judges](judges). + +```bash +$ python examples/scripts/evals/judge_tldr.py --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr --judge_model gpt-4o-mini --num_examples 1000 +Model win rate: 33.00% +$ python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/rloo_tldr --judge_model gpt-4o-mini --num_examples 1000 +Model win rate: 51.20% +``` + +The RLOO checkpoint gets a 51.2% preferred rate vs the 33.0% preference rate of the SFT checkpoint. This is a good sign that the RLOO training is working as intended. + + +Metrics: + +![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/rloo.png) + + +```bash +# pip install openrlbenchmark==0.2.1a5 +# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation +# to use it, change `?we=huggingface&wpn=trl` to your own project and `?tag=pr-1540` to your own tag +python -m openrlbenchmark.rlops_multi_metrics \ + --filters '?we=huggingface&wpn=trl&xaxis=train/episode&ceik=output_dir&cen=sft_model_path&metrics=train/objective/rlhf_reward&metrics=train/objective/scores&metrics=train/objective/kl&metrics=train/objective/non_score_reward&metrics=train/objective/entropy&metrics=train/policy/approxkl_avg&metrics=train/policy/clipfrac_avg&metrics=train/loss/policy_avg&metrics=train/policy/entropy_avg&metrics=train/val/ratio&metrics=train/val/ratio_var&metrics=train/val/num_eos_tokens&metrics=train/lr&metrics=train/eps' \ + "cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr?tag=pr-1540" \ + --env-ids models/minimal/rloo_tldr \ + --pc.ncols 4 \ + --pc.ncols-legend 1 \ + --pc.xlabel "Episode" \ + --output-filename benchmark/trl/pr-1540/rloo \ + --scan-history +``` + +## Reinforce++ + +The [Reinforce++](https://hijkzzz.notion.site/reinforce-plus-plus) report by Jian Hu suggests several optimization tricks to enhance performance and stability of RLHF. They include: + +- Clipping rewards: limiting reward values within a specific range to mitigate the impact of extreme rewards on model updates, thus preventing gradient explosion +- Normalizing rewards: scaling rewards to have a mean of 0 and a standard deviation of 1, which helps in stabilizing the training process +- Normalizing advantages: scaling advantages to have a mean of 0 and a standard deviation of 1, which helps in stabilizing the training process +- Using token-level KL penalty that is defined as equation (1) of the report vs. sequence-level KL penalty (default) + +These options are available via the appropriate arguments in the [`RLOOConfig`] class. + + +## RLOOTrainer + +[[autodoc]] RLOOTrainer + +## RLOOConfig + +[[autodoc]] RLOOConfig diff --git a/docs/source/script_utils.md b/docs/source/script_utils.md new file mode 100644 index 0000000000000000000000000000000000000000..aba81bf9f3a0e0ac0f75aea6c2890bcc12781fc0 --- /dev/null +++ b/docs/source/script_utils.md @@ -0,0 +1,12 @@ +# Scripts Utilities + +## ScriptArguments + +[[autodoc]] ScriptArguments + +## TrlParser + +[[autodoc]] TrlParser + - parse_args_and_config + - parse_args_into_dataclasses + - set_defaults_with_config diff --git a/docs/source/sentiment_tuning.md b/docs/source/sentiment_tuning.md new file mode 100644 index 0000000000000000000000000000000000000000..0637cb7ec312c37f327ae4ca031fdd7231799717 --- /dev/null +++ b/docs/source/sentiment_tuning.md @@ -0,0 +1,36 @@ +# Sentiment Tuning Examples + +The notebooks and scripts in this examples show how to fine-tune a model with a sentiment classifier (such as `lvwerra/distilbert-imdb`). + +Here's an overview of the notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples): + + + +| File | Description | +|------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------| +| [`examples/scripts/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment.ipynb) | This script shows how to use the `PPOTrainer` to fine-tune a sentiment analysis model using IMDB dataset | +| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. | +| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. + + + +## Usage + +```bash +# 1. run directly +python examples/scripts/ppo.py +# 2. run via `accelerate` (recommended), enabling more features (e.g., multiple GPUs, deepspeed) +accelerate config # will prompt you to define the training configuration +accelerate launch examples/scripts/ppo.py # launches training +# 3. get help text and documentation +python examples/scripts/ppo.py --help +# 4. configure logging with wandb and, say, mini_batch_size=1 and gradient_accumulation_steps=16 +python examples/scripts/ppo.py --log_with wandb --mini_batch_size 1 --gradient_accumulation_steps 16 +``` + +Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks. You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking). + + +## Few notes on multi-GPU + +To run in multi-GPU setup with DDP (distributed Data Parallel) change the `device_map` value to `device_map={"": Accelerator().process_index}` and make sure to run your script with `accelerate launch yourscript.py`. If you want to apply naive pipeline parallelism you can use `device_map="auto"`. \ No newline at end of file diff --git a/docs/source/sft_trainer.md b/docs/source/sft_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..9635b06db1c05e97ffeecf86eb8d8b4d54f7e37a --- /dev/null +++ b/docs/source/sft_trainer.md @@ -0,0 +1,670 @@ +# Supervised Fine-tuning Trainer + +[![](https://img.shields.io/badge/All_models-SFT-blue)](https://huggingface.co/models?other=sft,trl) [![](https://img.shields.io/badge/smol_course-Chapter_1-yellow)](https://github.com/huggingface/smol-course/tree/main/1_instruction_tuning) + +Supervised fine-tuning (SFT) is the most common step in post-training foundation models, and also one of the most effective. In TRL, we provide a simple API to train models with SFT in a few lines of code; for a complete training script, check out [`trl/scripts/sft.py`](https://github.com/huggingface/trl/tree/main/trl/scripts/sft.py). Experimental support for Vision Language Models is also included in [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/sft_vlm.py). + +## Quickstart + +If you have a dataset hosted on the 🤗 Hub, you can easily fine-tune your SFT model using [`SFTTrainer`] from TRL. Let us assume your dataset is `imdb`, the text you want to predict is inside the `text` field of the dataset, and you want to fine-tune the `facebook/opt-350m` model. +The following code-snippet takes care of all the data pre-processing and training for you: + +```python +from datasets import load_dataset +from trl import SFTConfig, SFTTrainer + +dataset = load_dataset("stanfordnlp/imdb", split="train") + +training_args = SFTConfig( + max_length=512, + output_dir="/tmp", +) +trainer = SFTTrainer( + "facebook/opt-350m", + train_dataset=dataset, + args=training_args, +) +trainer.train() +``` +Make sure to pass the correct value for `max_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`. + +You can also construct a model outside of the trainer and pass it as follows: + +```python +from transformers import AutoModelForCausalLM +from datasets import load_dataset +from trl import SFTConfig, SFTTrainer + +dataset = load_dataset("stanfordnlp/imdb", split="train") + +model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m") + +training_args = SFTConfig(output_dir="/tmp") + +trainer = SFTTrainer( + model, + train_dataset=dataset, + args=training_args, +) + +trainer.train() +``` + +The above snippets will use the default training arguments from the [`SFTConfig`] class. If you want to modify the defaults, pass in your modification to the `SFTConfig` constructor and pass it to the trainer via the `args` argument. + +## Advanced usage + +### Train on completions only + +To train on completions only, simply use a [prompt-completion](dataset_formats#prompt-completion) dataset. In this mode, loss is computed solely on the completion part. + +If you’d like to compute loss on both the prompt **and** the completion while still using a prompt-completion dataset, set `completion_only_loss=False` in the [`SFTConfig`]. This is equivalent to [converting the dataset to a language modeling](dataset_formats#from-prompt-completion-to-language-modeling-dataset) format. + +### Add Special Tokens for Chat Format + +Adding special tokens to a language model is crucial for training chat models. These tokens are added between the different roles in a conversation, such as the user, assistant, and system, and help the model recognize the structure and flow of a conversation. This setup is essential for enabling the model to generate coherent and contextually appropriate responses in a chat environment. +The [`setup_chat_format`] function in `trl` easily sets up a model and tokenizer for conversational AI tasks. This function: +- Adds special tokens to the tokenizer, e.g., `<|im_start|>` and `<|im_end|>`, to indicate the start and end of a conversation. +- Resizes the model’s embedding layer to accommodate the new tokens. +- Sets the `chat_template` of the tokenizer, which is used to format the input data into a chat-like format. The default is `chatml` from OpenAI. +- _optionally_ you can pass `resize_to_multiple_of` to resize the embedding layer to a multiple of the `resize_to_multiple_of` argument, e.g., `64`. If you want to see more formats being supported in the future, please open a GitHub issue on [trl](https://github.com/huggingface/trl) + +```python +from transformers import AutoModelForCausalLM, AutoTokenizer +from trl import setup_chat_format + +# Load model and tokenizer +model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m") +tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + +# Set up the chat format with the default 'chatml' format +model, tokenizer = setup_chat_format(model, tokenizer) +``` + +> [!WARNING] +> Some base models, like those from Qwen, have a predefined chat template in the model's tokenizer. In these cases, it is not necessary to apply `setup_chat_format()`, as the tokenizer already handles the formatting. However, it is necessary to align the EOS token with the chat template to ensure the model's responses terminate correctly. In these cases, specify `eos_token` in `SFTConfig`; for example, for `Qwen/Qwen2.5-1.5B`, one should set `eos_token="<|im_end|>"`. + +With our model and tokenizer set up, we can now fine-tune our model on a conversational dataset. Below is an example of how a dataset can be formatted for fine-tuning. + +### Dataset format support + +The [`SFTTrainer`] supports popular dataset formats. This allows you to pass the dataset to the trainer without any pre-processing directly. The following formats are supported: +* conversational format +```json +{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "..."}]} +{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "..."}]} +{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "..."}]} +``` +* instruction format +```json +{"prompt": "", "completion": ""} +{"prompt": "", "completion": ""} +{"prompt": "", "completion": ""} +``` + +If your dataset uses one of the above formats, you can directly pass it to the trainer without pre-processing. The [`SFTTrainer`] will then format the dataset for you using the defined format from the model's tokenizer with the [apply_chat_template](https://huggingface.co/docs/transformers/main/en/chat_templating#templates-for-chat-models) method. + + +```python +from datasets import load_dataset +from trl import SFTConfig, SFTTrainer + +... + +# load jsonl dataset +dataset = load_dataset("json", data_files="path/to/dataset.jsonl", split="train") +# load dataset from the HuggingFace Hub +dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train") + +... + +training_args = SFTConfig(packing=True) +trainer = SFTTrainer( + "facebook/opt-350m", + args=training_args, + train_dataset=dataset, +) +``` + +If the dataset is not in one of those formats, you can either preprocess the dataset to match the formatting or pass a formatting function to the SFTTrainer to do it for you. Let's have a look. + + +### Format your input prompts + +For instruction fine-tuning, it is quite common to have two columns inside the dataset: one for the prompt & the other for the response. +This allows people to format examples like [Stanford-Alpaca](https://github.com/tatsu-lab/stanford_alpaca) did as follows: +```bash +Below is an instruction ... + +### Instruction +{prompt} + +### Response: +{completion} +``` +Let us assume your dataset has two fields, `question` and `answer`. Therefore you can just run: +```python +... +def formatting_prompts_func(example): + return f"### Question: {example['question']}\n ### Answer: {example['answer']}" + + +trainer = SFTTrainer( + model, + args=training_args, + train_dataset=dataset, + formatting_func=formatting_prompt_func, +) + +trainer.train() +``` +To properly format your input, make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example of how to use SFTTrainer on the alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763) + +### Packing dataset + +[`SFTTrainer`] supports _example packing_, where multiple short examples are packed in the same input sequence to increase training efficiency. To enable the usage of this dataset class, simply pass `packing=True` to the [`SFTConfig`] constructor. + +```python +... +training_args = SFTConfig(packing=True) + +trainer = SFTTrainer( + "facebook/opt-350m", + train_dataset=dataset, + args=training_args +) + +trainer.train() +``` + +Note that if you use a packed dataset and if you pass `max_steps` in the training arguments, you will probably train your models for more than a few epochs, depending on the way you have configured the packed dataset and the training protocol. Double-check that you know and understand what you are doing. +If you don't want to pack your `eval_dataset`, you can pass `eval_packing=False` to the `SFTConfig` init method. + +#### Customize your prompts using packed dataset + +If your dataset has several fields that you want to combine, for example, if the dataset has `question` and `answer` fields and you want to combine them, you can pass a formatting function to the trainer that will take care of that. For example: + +```python +def formatting_func(example): + text = f"### Question: {example['question']}\n ### Answer: {example['answer']}" + return text + +training_args = SFTConfig(packing=True) +trainer = SFTTrainer( + "facebook/opt-350m", + train_dataset=dataset, + args=training_args, + formatting_func=formatting_func +) + +trainer.train() +``` + +### Control over the pretrained model + +You can directly pass the kwargs of the `from_pretrained()` method to the [`SFTConfig`]. For example, if you want to load a model in a different precision, analogous to + +```python +model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16) + +... + +training_args = SFTConfig( + model_init_kwargs={ + "torch_dtype": "bfloat16", + }, + output_dir="/tmp", +) +trainer = SFTTrainer( + "facebook/opt-350m", + train_dataset=dataset, + args=training_args, +) + +trainer.train() +``` +Note that all keyword arguments of `from_pretrained()` are supported. + +### Training adapters + +We also support tight integration with 🤗 PEFT library so that any user can conveniently train adapters and share them on the Hub instead of training the entire model. + +```python +from datasets import load_dataset +from trl import SFTConfig, SFTTrainer +from peft import LoraConfig + +dataset = load_dataset("trl-lib/Capybara", split="train") + +peft_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + target_modules="all-linear", + modules_to_save=["lm_head", "embed_token"], + task_type="CAUSAL_LM", +) + +trainer = SFTTrainer( + "Qwen/Qwen2.5-0.5B", + train_dataset=dataset, + args=SFTConfig(output_dir="Qwen2.5-0.5B-SFT"), + peft_config=peft_config +) + +trainer.train() +``` + +> [!WARNING] +> If the chat template contains special tokens like `<|im_start|>` (ChatML) or `<|eot_id|>` (Llama), the embedding layer and LM head must be included in the trainable parameters via the `modules_to_save` argument. Without this, the fine-tuned model will produce unbounded or nonsensical generations. If the chat template doesn't contain special tokens (e.g., Alpaca), then the `modules_to_save` argument can be ignored or set to `None`. + + +You can also continue training your `PeftModel`. For that, first load a `PeftModel` outside `SFTTrainer` and pass it directly to the trainer without the `peft_config` argument being passed. + +### Training adapters with base 8 bit models + +For that, you need to first load your 8 bit model outside the Trainer and pass a `PeftConfig` to the trainer. For example: + +```python +... + +peft_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +model = AutoModelForCausalLM.from_pretrained( + "EleutherAI/gpt-neo-125m", + load_in_8bit=True, + device_map="auto", +) + +trainer = SFTTrainer( + model, + train_dataset=dataset, + args=SFTConfig(), + peft_config=peft_config, +) + +trainer.train() +``` + +## Using Flash Attention and Flash Attention 2 + +You can benefit from Flash Attention 1 & 2 using SFTTrainer out of the box with minimal changes of code. +First, to make sure you have all the latest features from transformers, install transformers from source + +```bash +pip install -U git+https://github.com/huggingface/transformers.git +``` + +Note that Flash Attention only works on GPU now and under half-precision regime (when using adapters, base model loaded in half-precision) +Note also both features are perfectly compatible with other tools such as quantization. + +### Using Flash-Attention 1 + +For Flash Attention 1 you can use the `BetterTransformer` API and force-dispatch the API to use Flash Attention kernel. First, install the latest optimum package: + +```bash +pip install -U optimum +``` + +Once you have loaded your model, wrap the `trainer.train()` call under the `with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):` context manager: + +```diff +... + ++ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): + trainer.train() +``` + +Note that you cannot train your model using Flash Attention 1 on an arbitrary dataset as `torch.scaled_dot_product_attention` does not support training with padding tokens if you use Flash Attention kernels. Therefore, you can only use that feature with `packing=True`. If your dataset contains padding tokens, consider switching to Flash Attention 2 integration. + +Below are some numbers you can get in terms of speedup and memory efficiency, using Flash Attention 1, on a single NVIDIA-T4 16GB. + +| use_flash_attn_1 | model_name | max_seq_len | batch_size | time per training step | +| ---------------- | ----------------- | ----------- | ---------- | ---------------------- | +| ✓ | facebook/opt-350m | 2048 | 8 | ~59.1s | +| | facebook/opt-350m | 2048 | 8 | **OOM** | +| ✓ | facebook/opt-350m | 2048 | 4 | ~30.3s | +| | facebook/opt-350m | 2048 | 4 | ~148.9s | + +### Using Flash Attention-2 + +To use Flash Attention 2, first install the latest `flash-attn` package: + +```bash +pip install -U flash-attn +``` + +And add `attn_implementation="flash_attention_2"` when calling `from_pretrained`: + +```python +model = AutoModelForCausalLM.from_pretrained( + model_id, + load_in_4bit=True, + attn_implementation="flash_attention_2" +) +``` + +If you don't use quantization, make sure your model is loaded in half-precision and dispatch your model on a supported GPU device. +After loading your model, you can either train it as it is or attach adapters and train adapters on it in case your model is quantized. + +In contrast to Flash Attention 1, the integration makes it possible to train your model on an arbitrary dataset that also includes padding tokens. + + +### Using the model creation utility + +We included a utility function to create your model. + +[[autodoc]] ModelConfig + +```python +from trl import ModelConfig, SFTTrainer, get_kbit_device_map, get_peft_config, get_quantization_config +model_args = ModelConfig( + model_name_or_path="facebook/opt-350m" + attn_implementation=None, # or "flash_attention_2" +) +torch_dtype = ( + model_args.torch_dtype + if model_args.torch_dtype in ["auto", None] + else getattr(torch, model_args.torch_dtype) +) +quantization_config = get_quantization_config(model_args) +model_kwargs = dict( + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + torch_dtype=torch_dtype, + use_cache=False if training_args.gradient_checkpointing else True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, +) +model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs) +trainer = SFTTrainer( + ..., + model=model_args.model_name_or_path, + peft_config=get_peft_config(model_args), +) +``` + +### Enhance the model's performance using NEFTune + +NEFTune is a technique to boost the performance of chat models and was introduced by the paper ["NEFTune: Noisy Embeddings Improve Instruction Finetuning"](https://huggingface.co/papers/2310.05914) from Jain et al. It consists of adding noise to the embedding vectors during training. According to the abstract of the paper: + +> Standard finetuning of LLaMA-2-7B using Alpaca achieves 29.79% on AlpacaEval, which rises to 64.69% using noisy embeddings. NEFTune also improves over strong baselines on modern instruction datasets. Models trained with Evol-Instruct see a 10% improvement, with ShareGPT an 8% improvement, and with OpenPlatypus an 8% improvement. Even powerful models further refined with RLHF, such as LLaMA-2-Chat, benefit from additional training with NEFTune. + +
+ +
+ +To use it in `SFTTrainer`, simply pass `neftune_noise_alpha` when creating your `SFTConfig` instance. Note that to avoid any surprising behaviour, NEFTune is disabled after training to revert to the original behaviour of the embedding layer. + +```python +from datasets import load_dataset +from trl import SFTConfig, SFTTrainer + +dataset = load_dataset("stanfordnlp/imdb", split="train") + +training_args = SFTConfig( + neftune_noise_alpha=5, +) +trainer = SFTTrainer( + "facebook/opt-350m", + train_dataset=dataset, + args=training_args, +) +trainer.train() +``` + +We have tested NEFTune by training `mistralai/Mistral-7B-v0.1` on the [OpenAssistant dataset](https://huggingface.co/datasets/timdettmers/openassistant-guanaco) and validated that using NEFTune led to a performance boost of ~25% on MT Bench. + +
+ +
+ +Note however, that the amount of performance gain is _dataset dependent_ and in particular, applying NEFTune on synthetic datasets like [UltraChat](https://huggingface.co/datasets/stingning/ultrachat) typically produces smaller gains. + +### Accelerate fine-tuning 2x using `unsloth` + +You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently, `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek, etc) and Mistral architectures. Some benchmarks on 1x A100 listed below: + +| 1 A100 40GB | Dataset | 🤗 | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved | +| --------------- | --------- | --- | --------------------- | --------- | ------------ | +| Code Llama 34b | Slim Orca | 1x | 1.01x | **1.94x** | -22.7% | +| Llama-2 7b | Slim Orca | 1x | 0.96x | **1.87x** | -39.3% | +| Mistral 7b | Slim Orca | 1x | 1.17x | **1.88x** | -65.9% | +| Tiny Llama 1.1b | Alpaca | 1x | 1.55x | **2.74x** | -57.8% | + +First, install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows: + +```python +import torch +from trl import SFTConfig, SFTTrainer +from unsloth import FastLanguageModel + +max_length = 2048 # Supports automatic RoPE Scaling, so choose any number + +# Load model +model, tokenizer = FastLanguageModel.from_pretrained( + model_name="unsloth/mistral-7b", + max_seq_length=max_length, + dtype=None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+ + load_in_4bit=True, # Use 4bit quantization to reduce memory usage. Can be False + # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf +) + +# Do model patching and add fast LoRA weights +model = FastLanguageModel.get_peft_model( + model, + r=16, + target_modules=[ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", + ], + lora_alpha=16, + lora_dropout=0, # Dropout = 0 is currently optimized + bias="none", # Bias = "none" is currently optimized + use_gradient_checkpointing=True, + random_state=3407, +) + +training_args = SFTConfig(output_dir="./output", max_length=max_length) + +trainer = SFTTrainer( + model=model, + args=training_args, + train_dataset=dataset, +) +trainer.train() +``` + +The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth). + +## Liger-Kernel: Increase 20% throughput and reduce 60% memory for multi-GPU training + +[Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU training throughput by 20% and reduce memory usage by 60%. That way, we can **4x** our context length, as described in the benchmark below. They have implemented Hugging Face Compatible `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). + +With this memory reduction, you can potentially turn off `cpu_offloading` or gradient checkpointing to further boost the performance. + +| Speed Up | Memory Reduction | +|--------------------------|-------------------------| +| ![Speed up](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-tps.png) | ![Memory](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-memory.png) | + + +1. To use Liger-Kernel in [`SFTTrainer`], first install it by: + +```bash +pip install liger-kernel +``` + +2. Once installed, set `use_liger_kernel` in [`SFTConfig`]. No other changes are needed! + +```python +training_args = SFTConfig( + use_liger_kernel=True, + ... +) +``` + +To learn more about Liger-Kernel, visit their [official repository](https://github.com/linkedin/Liger-Kernel/). + +## Best practices + +Pay attention to the following best practices when training a model with that trainer: + +- [`SFTTrainer`] always truncates by default the sequences to the `max_length` argument of the [`SFTConfig`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 1024 and that value. Make sure to check it before training. +- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it. +- For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it. +- If you create a model outside the trainer, make sure not to pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method. + +## Multi-GPU Training + +Trainer (and thus SFTTrainer) supports multi-GPU training. If you run your script with `python script.py` it will default to using DP as the strategy, which may be [slower than expected](https://github.com/huggingface/trl/issues/1303). To use DDP (which is generally recommended, see [here](https://huggingface.co/docs/transformers/en/perf_train_gpu_many?select-gpu=Accelerate#data-parallelism) for more info) you must launch the script with `python -m torch.distributed.launch script.py` or `accelerate launch script.py`. For DDP to work, you must also check the following: +- If you're using gradient_checkpointing, add the following to the TrainingArguments: `gradient_checkpointing_kwargs={'use_reentrant':False}` (more info [here](https://github.com/huggingface/transformers/issues/26969) +- Ensure that the model is placed on the correct device: +```python +from accelerate import PartialState +device_string = PartialState().process_index +model = AutoModelForCausalLM.from_pretrained( + ... + device_map={'':device_string} +) +``` + +## GPTQ Conversion + +You may experience some issues with GPTQ Quantization after completing training. Lowering `gradient_accumulation_steps` to `4` will resolve most issues during the quantization process to GPTQ format. + +## Extending `SFTTrainer` for Vision Language Models + +`SFTTrainer` does not inherently support vision-language data. However, we provide a guide on how to tweak the trainer to support vision-language data. Specifically, you need to use a custom data collator that is compatible with vision-language data. This guide outlines the steps to make these adjustments. For a concrete example, refer to the script [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py), which demonstrates how to fine-tune the LLaVA 1.5 model on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset. + +### Preparing the Data + +The data format is flexible, provided it is compatible with the custom collator that we will define later. A common approach is to use conversational data. Given that the data includes both text and images, the format needs to be adjusted accordingly. Below is an example of a conversational data format involving both text and images: + +```python +images = ["obama.png"] +messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Who is this?"}, + {"type": "image"} + ] + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Barack Obama"} + ] + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "What is he famous for?"} + ] + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "He is the 44th President of the United States."} + ] + } +] +``` + +To illustrate how this data format will be processed using the LLaVA model, you can use the following code: + +```python +from transformers import AutoProcessor + +processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") +print(processor.apply_chat_template(messages, tokenize=False)) +``` + +The output will be formatted as follows: + +```txt +Who is this? ASSISTANT: Barack Obama USER: What is he famous for? ASSISTANT: He is the 44th President of the United States. +``` + + + + +### A custom collator for processing multi-modal data + +Unlike the default behavior of `SFTTrainer`, processing multi-modal data is done on the fly during the data collation process. To do this, you need to define a custom collator that processes both the text and images. This collator must take a list of examples as input (see the previous section for an example of the data format) and return a batch of processed data. Below is an example of such a collator: + +```python +def collate_fn(examples): + # Get the texts and images, and apply the chat template + texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples] + images = [example["images"][0] for example in examples] + + # Tokenize the texts and process the images + batch = processor(texts, images, return_tensors="pt", padding=True) + + # The labels are the input_ids, and we mask the padding tokens in the loss computation + labels = batch["input_ids"].clone() + labels[labels == processor.tokenizer.pad_token_id] = -100 + batch["labels"] = labels + + return batch +``` + +We can verify that the collator works as expected by running the following code: + +```python +from datasets import load_dataset + +dataset = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft", split="train") +examples = [dataset[0], dataset[1]] # Just two examples for the sake of the example +collated_data = collate_fn(examples) +print(collated_data.keys()) # dict_keys(['input_ids', 'attention_mask', 'pixel_values', 'labels']) +``` + +### Training the vision-language model + +Now that we have prepared the data and defined the collator, we can proceed with training the model. To ensure that the data is not processed as text-only, we need to set a couple of arguments in the `SFTConfig`, specifically `remove_unused_columns` and `skip_prepare_dataset` to `True` to avoid the default processing of the dataset. Below is an example of how to set up the `SFTTrainer`. + +```python +training_args.remove_unused_columns = False +training_args.dataset_kwargs = {"skip_prepare_dataset": True} + +trainer = SFTTrainer( + model=model, + args=training_args, + data_collator=collate_fn, + train_dataset=train_dataset, + processing_class=processor.tokenizer, +) +``` + +A full example of training LLaVa 1.5 on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset can be found in the script [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py). + +- [Experiment tracking](https://wandb.ai/huggingface/trl/runs/2b2c5l7s) +- [Trained model](https://huggingface.co/HuggingFaceH4/sft-llava-1.5-7b-hf) + +## SFTTrainer + +[[autodoc]] SFTTrainer + +## SFTConfig + +[[autodoc]] SFTConfig + +## Datasets + +In the SFTTrainer, we smartly support `datasets.IterableDataset` in addition to other style datasets. This is useful if you are using large corpora that you do not want to save all to disk. The data will be tokenized and processed on the fly, even when packing is enabled. + +Additionally, in the SFTTrainer, we support pre-tokenized datasets if they are `datasets.Dataset` or `datasets.IterableDataset`. In other words, if such a dataset has a column of `input_ids`, no further processing (tokenization or packing) will be done, and the dataset will be used as-is. This can be useful if you have pretokenized your dataset outside of this script and want to reuse it directly. diff --git a/docs/source/speeding_up_training.md b/docs/source/speeding_up_training.md new file mode 100644 index 0000000000000000000000000000000000000000..b4582ad8039f6d57ac66ea2617388629fef32299 --- /dev/null +++ b/docs/source/speeding_up_training.md @@ -0,0 +1,73 @@ +# Speeding Up Training + + + +Section under construction. Feel free to contribute! + + + +## vLLM for fast generation in online methods + +Online methods such as GRPO or Online DPO require the model to generate completions, which is often a slow process and can significantly impact training time. +To speed up generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation through, among other things, PagedAttention. TRL's online trainers support vLLM, greatly improving training speed. + +To use [vLLM](https://github.com/vllm-project/vllm), first install it using: + +```bash +pip install vllm +``` + +or + +```bash +pip install "trl[vllm]" +``` + + + + +Then, enable it by passing `use_vllm=True` in the training arguments. + +```python +from trl import OnlineDPOConfig + +training_args = OnlineDPOConfig(..., use_vllm=True) +``` + + + + +First, start a vLLM server by running: + +```bash +trl vllm-serve --model +``` + +Then, run the training script and pass `use_vllm=True` in the training arguments. + +```python +from trl import GRPOConfig + +training_args = GRPOConfig(..., use_vllm=True) +``` + +You can customize the server configuration by passing additional arguments. For more information, see [vLLM integration](vllm_integration). + + + +When using vLLM, ensure that the GPUs assigned for training and generation are separate to avoid resource conflicts. For instance, if you plan to use 4 GPUs for training and another 4 for vLLM generation, you can specify GPU allocation using `CUDA_VISIBLE_DEVICES`. + +Set GPUs **0-3** for vLLM generation: +```sh +CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model +``` + +And GPUs **4-7** for training: +```sh +CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py +``` + + + + + diff --git a/docs/source/training_vlm_sft.md b/docs/source/training_vlm_sft.md new file mode 100644 index 0000000000000000000000000000000000000000..9ab69e974ef7bbe7ecd7fa804a628aa170725bf1 --- /dev/null +++ b/docs/source/training_vlm_sft.md @@ -0,0 +1,381 @@ +# Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset) + +![VLM SFT training procedure](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/training_vlm_sft_training_procedure.png) + +## Overview + +This guide walks you through the process of fine-tuning a multimodal language model (e.g., **Gemma 3**) using **Supervised Fine-Tuning (SFT)**. We cover two cases: + +- **Single Image + Text** +- **Multi-Image + Text** + +This guide serves as a **detailed walkthrough** and complements the existing [VLM SFT script](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_gemma3.py). If you're already familiar with the concepts, you can use the script directly. + +We demonstrate the fine-tuning process using two datasets, but these principles extend to other **Vision-Language Models (VLMs)** and datasets. + +## Understanding the Datasets + +To address both **Single Image + Text** and **Multi-Image + Text** scenarios, we use two datasets that are well-suited for this task. + +### HuggingFaceH4/llava-instruct-mix-vsft Dataset (Image + Text) + +This dataset is a reformatted version of [LLaVA Instruct Mix](https://huggingface.co/datasets/theblackcat102/llava-instruct-mix). It consists of conversations where a user provides both **text** and a **single image** as input. + +The model (referred to as the **"assistant"**) responds based on both the **visual and textual information** shared by the user. This dataset is particularly useful for training multimodal models to **understand and generate responses based on images and text**. + + + +### FanqingM/MMIU-Benchmark Dataset (Multi-Image + Text) + +The **FanqingM/MMIU-Benchmark** dataset consists of: + +- **Context:** Included in the system prompt. +- **Question:** Provided as part of the user's input. +- **Series of Images:** Multiple images related to the question. +- **Answer:** The model's expected response. + +This dataset is designed for tasks where the model must reason over multiple images to generate an informed response based on both visual and textual inputs. + + + +## Developing a Fine-Tuning Script for Multimodal SFT + +In this section, we build the script needed to fine-tune a multimodal model for both **Single Image + Text** and **Multi-Image + Text** use cases. + +### Setting Up the Environment + +Before fine-tuning, we need to install the required dependencies. Let's start by setting up the environment: + +```bash +# Install the required libraries. Futher details: https://huggingface.co/docs/trl/installation +pip install -U -q trl bitsandbytes peft hf_xet tensorboard +``` + +Once all dependencies are installed, we need to log in to the **Hugging Face Hub**. Since **Gemma 3** is a gated model, access permissions are required. + +If you haven’t requested access yet, visit the [Model Card](https://huggingface.co/google/gemma-3-4b-it) and request it. + +To log in, you’ll need to generate an [access token](https://huggingface.co/settings/tokens) from your Hugging Face account. + +```bash +huggingface-cli login +``` + +### **Loading the Data** + +As mentioned earlier, we will cover two possible use cases. While the specific procedure may vary based on the dataset, the core principles remain consistent. + +This guide supports both use cases, so refer to the **Single Image + Text** or **Multi-Image + Text** sections depending on your specific scenario. + +#### **Single Image + Text** + +![Single Image + Text](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/training_vlm_sft_training_procedure_single_image.png) + +In this case, each sample in a batch consists of a **single image paired with text**. Since the dataset is already formatted for supervised fine-tuning (SFT), we can directly load it using `load_dataset`. + +```python +from datasets import load_dataset + +dataset_name = "HuggingFaceH4/llava-instruct-mix-vsft" + +# Load Dataset +dataset = load_dataset(dataset_name) +``` + +#### **Multi-Image + Text (or Interleaving)** + +![Multi-Image + Text](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/training_vlm_sft_training_procedure_multi_image.png) + +Gemma 3 also supports **Multi-Image + Text** scenarios, where: + +- The model receives a **list of images** alongside a user message. +- The model processes **interleaved images and text** within a conversation. + +For this dataset, some preprocessing is required before training. + +```python +from datasets import load_dataset + +dataset_name = "FanqingM/MMIU-Benchmark" + +# Load Dataset +dataset = load_dataset(dataset_name) +``` + +After loading the dataset, we need to preprocess and format it into a conversational structure. Here’s an example of how the data might look: + +```python +{"role": "system", "content": [{"type": "text", "text": "You are a judge in a photography competition, and now you are given the four images. Please examine the details and tell which one of them is most likely to be a real photograph.\nSelect from the following choices.\nA: the first image\nB: the second image\nC: the third image\nD: the fourth image"}]}, +{"role": "user", "content": images_list + [{"type": "text", "text": "Which image is most likely to be a real photograph?"}]}, +{"role": "assistant", "content": [{"type": "text", "text": "A: the first image\nB: the second image\nC: the third image\nD: the fourth image"}]}, +``` + +Here, `images_list` is a list of images: + +```python +images_list = [ + {"type": "image", "image": }, + {"type": "image", "image": }, + {"type": "image", "image": }, + {"type": "image", "image": }, + {"type": "image", "image": }, +] +``` + +This structure can be translated into code like this: + +```python +import os +import zipfile +import io +from datasets import DatasetDict +from huggingface_hub import hf_hub_download, list_repo_files +from PIL import Image + +dataset_train_split = "test" + +def format_data(samples: dict[str, any]) -> dict[str, list]: + formatted_samples = {"messages": []} + for cont in range(len(samples["question"])): + images = [] + for img_path in samples["input_image_path"][cont]: + try: + with open(img_path, "rb") as f: + img_bytes = f.read() + image = Image.open(io.BytesIO(img_bytes)).convert("RGB") + images.append({"type": "image", "image": image}) + except Exception as e: + print(f"Error processing image {img_path}: {e}") + continue + + formatted_samples["messages"].append( + [ + {"role": "system", "content": [{"type": "text", "text": samples["context"][cont]}]}, + {"role": "user", "content": images + [{"type": "text", "text": samples["question"][cont]}]}, + {"role": "assistant", "content": [{"type": "text", "text": samples["output"][cont]}]}, + ] + ) + return formatted_samples + +# For multi-image example +def prepare_dataset(dataset: DatasetDict, dataset_name: str, dataset_train_split: str) -> DatasetDict: + all_files = list_repo_files(dataset_name, repo_type="dataset") + zip_files = [f for f in all_files if f.endswith(".zip")] + + for zip_filename in zip_files: + zip_path = hf_hub_download(repo_id=dataset_name, filename=zip_filename, repo_type="dataset") + extract_folder = zip_filename.replace(".zip", "") + os.makedirs(extract_folder, exist_ok=True) + + with zipfile.ZipFile(zip_path, "r") as zip_ref: + zip_ref.extractall(extract_folder) + + dataset = dataset.map(format_data, batched=True, batch_size=4, num_proc=16) + return dataset + +dataset = prepare_dataset(dataset, dataset_name, dataset_train_split) +``` + +With this, your **Multi-Image + Text** dataset is now prepared for training. + +### **Preparing for Training** + +We start by loading the model and processor. In this example, we use `google/gemma-3-4b-it`, but the same process applies to its other variants and similar models. + +To optimize memory usage, we configure `BitsAndBytes` to load the quantized version of the model. + +```python +import torch +from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig + +model_id = "google/gemma-3-4b-it" + +# BitsAndBytesConfig int-4 config +bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_quant_storage=torch.bfloat16, +) + +# Load model and tokenizer +model = AutoModelForImageTextToText.from_pretrained( + model_id, + device_map="auto", + torch_dtype=torch.bfloat16, + attn_implementation="eager", # Important (Ref: https://github.com/huggingface/transformers/blob/c15a7adb283fa984a40558c7fe7bed30ae975cdd/src/transformers/models/gemma3/modeling_gemma3.py#L934) + quantization_config=bnb_config +) +processor = AutoProcessor.from_pretrained(model_id) +processor.tokenizer.padding_side = "right" +``` + +Next, we set up [Quantized Low-Rank Adaptation (QLoRA)](https://huggingface.co/papers/2305.14314), an efficient fine-tuning technique for Large Language Models (LLMs) and Vision-Language Models (VLMs). + +```python +from peft import LoraConfig, get_peft_model + +# Configure QLoRA +peft_config = LoraConfig( + lora_alpha=16, + lora_dropout=0.05, + r=16, + bias="none", + target_modules="all-linear", + task_type="CAUSAL_LM", + modules_to_save=[ + "lm_head", + "embed_tokens", + ], +) +``` + +With QLoRA now set up, we need to define the training arguments for SFT. The [`SFTConfig`] class simplifies this process, providing an easy way to adjust parameters based on our specific needs. + +```python +from trl import SFTConfig + +training_args = SFTConfig( + output_dir="gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft", # Directory to save the model and push to the Hub. Use a specific repository id (e.g., gemma-3-4b-it-trl-sft-MMIU-Benchmark for multi-image datasets). + num_train_epochs=1, # Set the number of epochs to train the model. + per_device_train_batch_size=8, # Batch size for each device (e.g., GPU) during training. multi-image -> per_device_train_batch_size=1 + gradient_accumulation_steps=4, # Number of steps before performing a backward/update pass to accumulate gradients. multi-image -> gradient_accumulation_steps=1 + gradient_checkpointing=True, # Enable gradient checkpointing to reduce memory usage during training. + optim="adamw_torch_fused", # Use the fused AdamW optimizer for better performance. + logging_steps=10, # Frequency of logging training progress (log every 10 steps). + save_strategy="epoch", # Save checkpoints at the end of each epoch. + learning_rate=2e-05, # Learning rate for training. + bf16=True, # Enable bfloat16 precision for training to save memory and speed up computations. + push_to_hub=True, # Automatically push the fine-tuned model to Hugging Face Hub after training. + report_to="tensorboard", # Automatically report metrics to tensorboard. + gradient_checkpointing_kwargs={"use_reentrant": False}, # Set gradient checkpointing to non-reentrant to avoid issues. + dataset_kwargs={"skip_prepare_dataset": True}, # Skip dataset preparation to handle preprocessing manually. + remove_unused_columns=False, # Ensure unused columns are not removed in the collator (important for batch processing). +) +``` + +The `collate_fn` is responsible for processing and preparing individual examples to form a batch. + +Each example in the batch undergoes the following steps: + +1. The **chat template** is applied to the text. +2. The **processor tokenizes** both `texts` and `images`, encoding them into tensors. +3. The **labels** for training are set as the `input_ids` of the example. +4. Certain **special tokens** are **masked (ignored)** during loss computation: + - `pad_token_id` + - `` + - `` (corresponding to ID `262144`) + +This process is similar across different dataset types, with a minor variation in how images are handled: + +- **Single Image + Text** → A **list of images** is directly processed. +- **Multi-Image + Text** → A **list of lists of images** is used, where each batch element contains multiple images. + +```python +from PIL import Image + +# For multi-image cases +def process_vision_info(messages: list[dict]) -> list[Image.Image]: + image_inputs = [] + for msg in messages: + content = msg.get("content", []) + if not isinstance(content, list): + content = [content] + + for element in content: + if isinstance(element, dict) and ("image" in element or element.get("type") == "image"): + if "image" in element: + image = element["image"] + else: + image = element + if image is not None: + image = Image.open(io.BytesIO(image["bytes"])) + image_inputs.append(image.convert("RGB")) + return image_inputs + +def collate_fn(examples): + texts = [processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=False).strip() for example in examples] + if "images" in examples[0]: # single-image + images = [ + [img.convert("RGB") for img in example["images"]] + for example in examples + ] + else: # multi-image + images = [process_vision_info(example["messages"]) for example in examples] + + # Tokenize the texts and process the images + batch = processor( + text=texts, images=images, return_tensors="pt", padding=True + ) # Encode texts and images into tensors + + # The labels are the input_ids, and we mask the padding tokens in the loss computation + labels = batch["input_ids"].clone() # Clone input IDs for labels + # Mask image tokens + image_token_id = [ + processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.special_tokens_map["boi_token"]) + ] + # Mask tokens for not being used in the loss computation + labels[labels == processor.tokenizer.pad_token_id] = -100 + labels[labels == image_token_id] = -100 + labels[labels == 262144] = -100 + + batch["labels"] = labels + return batch # Return the prepared batch +``` + +### **Training the Model** + +With all the components set up, we now configure the `SFTTrainer` using the previously defined settings and start the training process. + +``` python +# Training +from trl import SFTTrainer + +trainer = SFTTrainer( + model=model, + args=training_args, + data_collator=collate_fn, + train_dataset=dataset["train"], # multi-image -> train_dataset=dataset["test"], + processing_class=processor, + peft_config=peft_config, +) + +trainer.train() + +# Save the final model +trainer.save_model() +``` + +We save the fine-tuned model to the Hub, making it easily accessible for future use. Additionally, TRL automatically logs the training results to **Weights & Biases (Wandb)** or **TensorBoard**, depending on the chosen configuration. + + +### Results + +During and after trainig, we can inspect the results using **Weights & Biases (Wandb)** or **TensorBoard**. For example: + +* [**gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft (Single Image+Text)**](https://huggingface.co/sergiopaniego/gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft) + +* [**gemma-3-4b-it-trl-sft-MMIU-Benchmark (Multi-Images+Text or Interleaving)**](https://huggingface.co/sergiopaniego/gemma-3-4b-it-trl-sft-MMIU-Benchmark) + +## Limitations + +Currently, fine-tuning Gemma has some [known limitations](https://github.com/huggingface/trl/issues/3121). We recommend following the procedure outlined in this guide to ensure the best results. + +## References + +For further reading and complementary resources, check out the following: + +- [Fine-Tuning Vision-Language Models with QLoRA](https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora) +- [Fine-Tuning a Vision Language Model (Qwen2-VL-7B) with the Hugging Face Ecosystem (TRL)](https://huggingface.co/learn/cookbook/fine_tuning_vlm_trl) + diff --git a/docs/source/unsloth_integration.md b/docs/source/unsloth_integration.md new file mode 100644 index 0000000000000000000000000000000000000000..dd071d392d67d4e0752d7dcd8d1a92cd527033e9 --- /dev/null +++ b/docs/source/unsloth_integration.md @@ -0,0 +1,7 @@ +# Unsloth Integration + + + +Section under construction. Feel free to contribute! + + \ No newline at end of file diff --git a/docs/source/use_model.md b/docs/source/use_model.md new file mode 100644 index 0000000000000000000000000000000000000000..f5ab1e45946460fc80d64f54136482b12400d059 --- /dev/null +++ b/docs/source/use_model.md @@ -0,0 +1,58 @@ +# Use model after training + +Once you have trained a model using either the SFTTrainer, PPOTrainer, or DPOTrainer, you will have a fine-tuned model that can be used for text generation. In this section, we'll walk through the process of loading the fine-tuned model and generating text. If you need to run an inference server with the trained model, you can explore libraries such as [`text-generation-inference`](https://github.com/huggingface/text-generation-inference). + +## Load and Generate + +If you have fine-tuned a model fully, meaning without the use of PEFT you can simply load it like any other language model in transformers. E.g. the value head that was trained during the PPO training is no longer needed and if you load the model with the original transformer class it will be ignored: + +```python +from transformers import AutoTokenizer, AutoModelForCausalLM + +model_name_or_path = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub +device = "cpu" # or "cuda" if you have a GPU + +model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(device) +tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) + +inputs = tokenizer.encode("This movie was really", return_tensors="pt").to(device) +outputs = model.generate(inputs) +print(tokenizer.decode(outputs[0])) +``` + +Alternatively you can also use the pipeline: + +```python +from transformers import pipeline + +model_name_or_path = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub +pipe = pipeline("text-generation", model=model_name_or_path) +print(pipe("This movie was really")[0]["generated_text"]) +``` + +## Use Adapters PEFT + +```python +from peft import PeftConfig, PeftModel +from transformers import AutoModelForCausalLM, AutoTokenizer + +base_model_name = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub" +adapter_model_name = "path/to/my/adapter" + +model = AutoModelForCausalLM.from_pretrained(base_model_name) +model = PeftModel.from_pretrained(model, adapter_model_name) + +tokenizer = AutoTokenizer.from_pretrained(base_model_name) +``` + +You can also merge the adapters into the base model so you can use the model like a normal transformers model, however the checkpoint will be significantly bigger: + +```python +model = AutoModelForCausalLM.from_pretrained(base_model_name) +model = PeftModel.from_pretrained(model, adapter_model_name) + +model = model.merge_and_unload() +model.save_pretrained("merged_adapters") +``` + +Once you have the model loaded and either merged the adapters or keep them separately on top you can run generation as with a normal model outlined above. diff --git a/docs/source/using_llama_models.md b/docs/source/using_llama_models.md new file mode 100644 index 0000000000000000000000000000000000000000..749dd6480584019ab394a329469ab65025bcca3c --- /dev/null +++ b/docs/source/using_llama_models.md @@ -0,0 +1,159 @@ +# Using LLaMA models with TRL + +We've begun rolling out examples to use Meta's LLaMA models in `trl` (see [Meta's LLaMA release](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) for the original LLaMA model). + +## Efficient training strategies + +Even training the smallest LLaMA model requires an enormous amount of memory. Some quick math: in bf16, every parameter uses 2 bytes (in fp32 4 bytes) in addition to 8 bytes used, e.g., in the Adam optimizer (see the [performance docs](https://huggingface.co/docs/transformers/perf_train_gpu_one#optimizer) in Transformers for more info). So a 7B parameter model would use `(2+8)*7B=70GB` just to fit in memory and would likely need more when you compute intermediate values such as attention scores. So you couldn’t train the model even on a single 80GB A100 like that. You can use some tricks, like more efficient optimizers of half-precision training, to squeeze a bit more into memory, but you’ll run out sooner or later. + +Another option is to use Parameter-Efficient Fine-Tuning (PEFT) techniques, such as the [`peft`](https://github.com/huggingface/peft) library, which can perform low-rank adaptation (LoRA) on a model loaded in 8-bit. +For more on `peft` + `trl`, see the [docs](https://huggingface.co/docs/trl/sentiment_tuning_peft). + +Loading the model in 8bit reduces the memory footprint drastically since you only need one byte per parameter for the weights (e.g. 7B LlaMa is 7GB in memory). +Instead of training the original weights directly, LoRA adds small adapter layers on top of some specific layers (usually the attention layers); thus, the number of trainable parameters is drastically reduced. + +In this scenario, a rule of thumb is to allocate ~1.2-1.4GB per billion parameters (depending on the batch size and sequence length) to fit the entire fine-tuning setup. +This enables fine-tuning larger models (up to 50-60B scale models on a NVIDIA A100 80GB) at low cost. + +Now we can fit very large models into a single GPU, but the training might still be very slow. +The simplest strategy in this scenario is data parallelism: we replicate the same training setup into separate GPUs and pass different batches to each GPU. +With this, you can parallelize the forward/backward passes of the model and scale with the number of GPUs. + +![chapter10_ddp.png](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/chapter10_ddp.png) + +We use either the `transformers.Trainer` or `accelerate`, which both support data parallelism without any code changes, by simply passing arguments when calling the scripts with `torchrun` or `accelerate launch`. The following runs a training script with 8 GPUs on a single machine with `accelerate` and `torchrun`, respectively. + +```bash +accelerate launch --multi_gpu --num_machines 1 --num_processes 8 my_accelerate_script.py +torchrun --nnodes 1 --nproc_per_node 8 my_torch_script.py +``` + +## Supervised fine-tuning + +Before we start training reward models and tuning our model with RL, it helps if the model is already good in the domain we are interested in. +In our case, we want it to answer questions, while for other use cases, we might want it to follow instructions, in which case instruction tuning is a great idea. +The easiest way to achieve this is by continuing to train the language model with the language modeling objective on texts from the domain or task. +The [StackExchange dataset](https://huggingface.co/datasets/HuggingFaceH4/stack-exchange-preferences) is enormous (over 10 million instructions), so we can easily train the language model on a subset of it. + +There is nothing special about fine-tuning the model before doing RLHF - it’s just the causal language modeling objective from pretraining that we apply here. +To use the data efficiently, we use a technique called packing: instead of having one text per sample in the batch and then padding to either the longest text or the maximal context of the model, we concatenate a lot of texts with a EOS token in between and cut chunks of the context size to fill the batch without any padding. + +![chapter10_preprocessing-clm.png](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/chapter10_preprocessing-clm.png) + +With this approach the training is much more efficient as each token that is passed through the model is also trained in contrast to padding tokens which are usually masked from the loss. +If you don't have much data and are more concerned about occasionally cutting off some tokens that are overflowing the context you can also use a classical data loader. + + +```python +# load model in 8bit +model = AutoModelForCausalLM.from_pretrained( + args.model_path, + load_in_8bit=True, + device_map={"": Accelerator().local_process_index} + ) +model = prepare_model_for_kbit_training(model) + +# add LoRA to model +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) + +model = get_peft_model(model, config) +``` + +We train the model for a few thousand steps with the causal language modeling objective and save the model. +Since we will tune the model again with different objectives, we merge the adapter weights with the original model weights. + +**Disclaimer:** due to LLaMA's license, we release only the adapter weights for this and the model checkpoints in the following sections. +You can apply for access to the base model's weights by filling out Meta AI's [form](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform) and then converting them to the 🤗 Transformers format by running this [script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py). +Note that you'll also need to install 🤗 Transformers from source until the `v4.28` is released. + +Now that we have fine-tuned the model for the task, we are ready to train a reward model. + +## Reward modeling and human preferences + +In principle, we could fine-tune the model using RLHF directly with the human annotations. +However, this would require us to send some samples to humans for rating after each optimization iteration. +This is expensive and slow due to the number of training samples needed for convergence and the inherent latency of human reading and annotator speed. + +A trick that works well instead of direct feedback is training a reward model on human annotations collected before the RL loop. +The goal of the reward model is to imitate how a human would rate a text. There are several possible strategies to build a reward model: the most straightforward way would be to predict the annotation (e.g. a rating score or a binary value for “good”/”bad”). +In practice, what works better is to predict the ranking of two examples, where the reward model is presented with two candidates `(y_k, y_j)` for a given prompt `x` and has to predict which one would be rated higher by a human annotator. + +With the StackExchange dataset, we can infer which of the two answers was preferred by the users based on the score. +With that information and the loss defined above, we can then modify the `transformers.Trainer` by adding a custom loss function. + +```python +class RewardTrainer(Trainer): + def compute_loss(self, model, inputs, return_outputs=False): + rewards_j = model(input_ids=inputs["input_ids_j"], attention_mask=inputs["attention_mask_j"])[0] + rewards_k = model(input_ids=inputs["input_ids_k"], attention_mask=inputs["attention_mask_k"])[0] + loss = -nn.functional.logsigmoid(rewards_j - rewards_k).mean() + if return_outputs: + return loss, {"rewards_j": rewards_j, "rewards_k": rewards_k} + return loss +``` + +We utilize a subset of a 100,000 pair of candidates and evaluate on a held-out set of 50,000. With a modest training batch size of 4, we train the Llama model using the LoRA `peft` adapter for a single epoch using the Adam optimizer with BF16 precision. Our LoRA configuration is: + +```python +peft_config = LoraConfig( + task_type=TaskType.SEQ_CLS, + inference_mode=False, + r=8, + lora_alpha=32, + lora_dropout=0.1, +) +``` +As detailed in the next section, the resulting adapter can be merged into the frozen model and saved for further downstream use. + +## Reinforcement Learning from Human Feedback + +With the fine-tuned language model and the reward model at hand, we are now ready to run the RL loop. It follows roughly three steps: + +1. Generate responses from prompts, +2. Rate the responses with the reward model, +3. Run a reinforcement learning policy-optimization step with the ratings. + +The Query and Response prompts are templated as follows before being tokenized and passed to the model: + +```bash +Question: + +Answer: +``` + +The same template was used for SFT, RM and RLHF stages. +Once more, we utilize `peft` for memory-efficient training, which offers an extra advantage in the RLHF context. +Here, the reference model and policy share the same base, the SFT model, which we load in 8-bit and freeze during training. +We exclusively optimize the policy's LoRA weights using PPO while sharing the base model's weights. + +```python +for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)): + question_tensors = batch["input_ids"] + + # sample from the policy and to generate responses + response_tensors = ppo_trainer.generate( + question_tensors, + return_prompt=False, + length_sampler=output_length_sampler, + **generation_kwargs, + ) + batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True) + + # Compute sentiment score + texts = [q + r for q, r in zip(batch["query"], batch["response"])] + pipe_outputs = sentiment_pipe(texts, **sent_kwargs) + rewards = [torch.tensor(output[0]["score"] - script_args.reward_baseline) for output in pipe_outputs] + + # Run PPO step + stats = ppo_trainer.step(question_tensors, response_tensors, rewards) + # Log stats to Wandb + ppo_trainer.log_stats(stats, batch, rewards) +``` + +For the rest of the details and evaluation, please refer to our [blog post on StackLLaMA](https://huggingface.co/blog/stackllama). diff --git a/docs/source/vllm_integration.md b/docs/source/vllm_integration.md new file mode 100644 index 0000000000000000000000000000000000000000..9a3e7762b1a111a52afb4157c78d9e084d5deb5d --- /dev/null +++ b/docs/source/vllm_integration.md @@ -0,0 +1,185 @@ +# vLLM Integration + +This document will guide you through the process of using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first summarize a tl;dr on how to use vLLM with TRL, and then we will go into the details of how it works under the hood. Let's go! 🔥 + +## 🚀 How can I use vLLM with TRL to speed up training? + +💡 **Note**: Resources required for this specific example: a single node with 8 GPUs. + +First, install vLLM using the following command: + +```bash +pip install "trl[vllm]" +``` + +Then run the server: + +```sh +trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 2 --data-parallel-size 2 +``` + +Once the server is running, you can use it to generate completions for training. In the example below, we are using the `GRPOTrainer` to train a model using the vLLM server for generation. The `--tensor-parallel-size` and `--data-parallel-size` arguments control how the model and data are sharded across GPUs. + +In this example, we are sharding two copies of the model across 4 GPUs. Increasing data parallelism increases throughput, while increasing tensor parallelism allows for serving larger models. Then, run the training script by passing `use_vllm=True` in the training arguments as follows: + +Sample of a simple `train.py` script: + +```python +from datasets import load_dataset +from trl import GRPOTrainer, GRPOConfig + +dataset = load_dataset("trl-lib/tldr", split="train") + +# Dummy reward function: count the number of unique characters in the completions +def reward_num_unique_chars(completions, **kwargs): + return [len(set(c)) for c in completions] + +training_args = GRPOConfig( + output_dir="my_test", + use_vllm=True, + bf16=True, + gradient_checkpointing=True, + logging_steps=10, +) + +trainer = GRPOTrainer( + model="Qwen/Qwen2.5-7B", + args=training_args, + reward_funcs=reward_num_unique_chars, + train_dataset=dataset, +) + +trainer.train() +``` + +And the train command: + +```sh +CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py +``` + +## 🎬 Flashback: Why do we need to use vLLM in online methods? + +Online methods like GRPO or Online DPO require the model to generate completions during training, which are then used to compute reward signals. However, generation can be extremely time-consuming, especially with large or reasoning models. In the default setup (without vLLM), completions are generated using the [(unwrapped) model's `generate` method](https://github.com/huggingface/trl/blob/f3e8c2304428ef16e9ae5de9e5741ed84d533b7b/trl/trainer/grpo_trainer.py#L965C39-L965C66). This approach quickly becomes a major bottleneck — generation is slow and inefficient, particularly for large batches or models. As a result, training times increase significantly, and overall efficiency drops. To address this, we turn to vLLM, which enables much faster and more scalable generation, helping eliminate this bottleneck in online methods. + +## 🤔 How does vLLM solve the slow generation issue? + +If you've ever done autoregressive decoder training, you know all the input tokens to the LLM produce their attention key and value tensors, and these tensors are kept in GPU memory to later generate subsequent tokens based on them. These cached key and value tensors are often referred to as the KV cache. However, storing the KV cache occupies a lot of memory, so vLLM uses a technique called **PagedAttention** to solve this problem. PagedAttention, which is inspired by the OS’s virtual memory concept, stores continuous keys and values in **non-contiguous memory space**, which is much more efficient. The details of this are beyond the scope of this document, but in short, it allows the model to store the keys and values in a more efficient way, reducing the memory footprint and speeding up the generation process. If you are interested, make sure to check out the [vLLM PagedAttention](https://blog.vllm.ai/2023/06/20/vllm.html) for more details. + +## 🤔 What exactly happens when you run `trl vllm-serve --model `? + +When you run for example + +```sh +trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 1 --data-parallel-size 4 +``` + +the following happens: + +![vllm](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/vllm-doc.png) + +1. vLLM first spawns multiple workers to handle incoming requests in parallel. The number of workers is determined by multiplying the `--tensor-parallel-size` and `--data-parallel-size` values. In this example, it spawns 4 workers (1 × 4). +Each worker operates independently and processes a chunk of the incoming requests — which are basically the prompts sent to the server for generation. A key point to understand is that these 4 workers are running in parallel, and each one is responsible for handling a subset of the total incoming load. + +2. Once the incoming requests (prompts) are distributed across the workers, the model starts generating completions. Internally, the model’s weights are split across multiple GPUs based on the `--tensor-parallel-size` argument — this is how tensor parallelism is handled. Meanwhile, data parallelism (controlled by `--data-parallel-size`) ensures that different sets of requests are processed independently across the workers. In short: tensor parallelism splits the model across GPUs, and data parallelism splits the batch of requests across different model replicas. + +3. Although the GPUs process requests independently and in parallel, they still need to communicate with each other. Remember that each GPU handles only a slice of the incoming prompts (for example, with 4 GPUs and 8 prompts using `--data-parallel-size=4`, each GPU processes 2 prompts). +This GPU-to-GPU communication is managed efficiently by NVIDIA’s NCCL library. The communication mainly ensures that each GPU gets its correct portion of the incoming requests — it’s lightweight and doesn’t interfere with generation itself. +Separately, the number of completions to generate per prompt is controlled by the `num_generations` setting in the GRPO config. For instance, if you set `num_generations=2` (like in the picture above), each prompt will have 2 completions. So, with 8 prompts and `num_generations=2`, you would end up with 16 completions total — regardless of the number of GPUs or parallelism settings. + +## 🥸 More detail on what happens under the hood when running the server + +* The vLLM server starts by running the command: `trl vllm-serve --model Qwen/Qwen2.5-7B`. +* Once the server is running, it generates completions based on requests from the client (trainer) using `vllm_client.generate` [here](https://github.com/huggingface/trl/blob/cc044e35b285be7dc062764b3364e1e684db4c7c/trl/trainer/grpo_trainer.py#L1025-L1035). +* The client (trainer) then requests these completions from the server. +* These completions are used to compute the reward signal. +* Based on the reward signal and the model’s output, the loss is computed, and the backward pass is performed to update the model’s weights. +* **Note**: The server only handles completion generation — it doesn’t train the model. Therefore, the model’s weights aren’t updated on the server. Once the backward pass is complete, the client sends the updated weights to the server using `vllm_client.update_named_param(name, param.data)`. + +When using vLLM, ensure the GPUs assigned for training and generation are separate to avoid resource conflicts. For instance, if you plan to use 4 GPUs for training and another 4 for vLLM generation, you can specify GPU allocation for training using `CUDA_VISIBLE_DEVICES`. See the example below: + +* **Set GPUs *0–3* for vLLM generation:** Assume `CUDA_VISIBLE_DEVICES=0,1,2,3` are allocated for vLLM generation. + +```sh +trl vllm-serve --model --tensor-parallel-size 1 --data-parallel-size 4 +``` + +* **And GPUs *4–7* for training:** If you do not set the `CUDA_VISIBLE_DEVICES` environment variable, the training script will use all available GPUs by default, which may lead to resource conflicts. To avoid this, you can specify which GPUs to use for training. For example, if you want to use GPUs 4–7 for training, set the environment variable as follows: + +```sh +CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py +``` + +## 🍷 More customization options with vLLM? + +You can customize the server configuration by passing additional arguments. + +``` +$ trl vllm-serve --help +usage: trl vllm-serve [-h] --model MODEL [--revision REVISION] [--tensor_parallel_size TENSOR_PARALLEL_SIZE] + [--data_parallel_size DATA_PARALLEL_SIZE] [--host HOST] [--port PORT] + [--gpu_memory_utilization GPU_MEMORY_UTILIZATION] [--dtype DTYPE] [--max_model_len MAX_MODEL_LEN] + [--enable_prefix_caching ENABLE_PREFIX_CACHING] [--enforce_eager ENFORCE_EAGER] [--log_level LOG_LEVEL] + +options: + -h, --help Show this help message and exit + --model MODEL Model name or path to load the model from. (default: None) + --revision REVISION Revision to use for the model. If not specified, the default branch will be used. (default: None) + --tensor_parallel_size TENSOR_PARALLEL_SIZE, --tensor-parallel-size TENSOR_PARALLEL_SIZE + Number of tensor parallel workers to use. (default: 1) + --data_parallel_size DATA_PARALLEL_SIZE, --data-parallel-size DATA_PARALLEL_SIZE + Number of data parallel workers to use. (default: 1) + --host HOST Host address to run the server on. (default: 0.0.0.0) + --port PORT Port to run the server on. (default: 8000) + --gpu_memory_utilization GPU_MEMORY_UTILIZATION, --gpu-memory-utilization GPU_MEMORY_UTILIZATION + Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the device + dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus improve the + model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors during + initialization. (default: 0.9) + --dtype DTYPE Data type to use for vLLM generation. If set to 'auto', the data type will be automatically determined based on + the model configuration. Find the supported values in the vLLM documentation. (default: auto) + --max_model_len MAX_MODEL_LEN, --max-model-len MAX_MODEL_LEN + If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced + `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model context + size, which might be much larger than the KV cache, leading to inefficiencies. (default: None) + --enable_prefix_caching ENABLE_PREFIX_CACHING, --enable-prefix-caching ENABLE_PREFIX_CACHING + Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware support this + feature. (default: None) + --enforce_eager ENFORCE_EAGER, --enforce-eager ENFORCE_EAGER + Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always execute the model + in eager mode. If `False` (default behavior), we will use CUDA graph and eager execution in hybrid. (default: + None) + --log_level LOG_LEVEL, --log-level LOG_LEVEL + Log level for uvicorn. Possible choices: 'critical', 'error', 'warning', 'info', 'debug', 'trace'. (default: + info) +``` + +## 🥳 Okay, now that we have the server running, how can we use it to generate completions? + +Run the training script and pass `use_vllm=True` in the training arguments: + +```python +from trl import GRPOConfig + +training_args = GRPOConfig(..., use_vllm=True) +``` + +## 💆🏻‍♀️ What's the best distributed setup? + +![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/tp_dp_throughput_8_gpus.png) +![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/tp_dp_throughput_4_gpus.png) + +First and foremost, always remember that the optimal setup depends on: + +* The model size +* The number of GPUs you have +* The GPU memory size +* The batch size you are using +* The number of requests you are sending to the server (prompts) +* The `max_model_len` you are using (this is the max length of the input sequence that the model can process, a.k.a. the context window size) +* The number of completions you are generating for each request (`num_generations`) + +Given these factors, our experiments on the Qwen model family (3B, 7B, 14B, 32B) using 8 H100 GPUs show that: + +* For reasonable-sized models (3B–14B) and a moderate context window (`max_len < 8k`), using full capacity for data parallelism gives better throughput. The setup `(tp=1, dp=8)` yields the best results. +* For larger models (32B) and longer context windows (`max_len > 8k`), a smaller DP size combined with some model-side parallelism performs better. For example, `(tp=2, dp=4)` is a good setup for 32B models with a larger context window. diff --git a/docs/source/xpo_trainer.md b/docs/source/xpo_trainer.md new file mode 100644 index 0000000000000000000000000000000000000000..587f4f6f3cc92a71782b3a55a7d18e751a2c8bac --- /dev/null +++ b/docs/source/xpo_trainer.md @@ -0,0 +1,162 @@ +# XPO Trainer + +[![](https://img.shields.io/badge/All_models-XPO-blue)](https://huggingface.co/models?other=xpo,trl) + +## Overview + +Exploratory Preference Optimization (XPO) was proposed in the paper [Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF](https://huggingface.co/papers/2405.21046) by Tengyang Xie, Dylan J. Foster, Akshay Krishnamurthy, [Corby Rosset](https://huggingface.co/corbyrosset), [Ahmed Awadallah](https://huggingface.co/AhmedAwadallah), and Alexander Rakhlin. It is a simple online preference tuning method based on the DPO loss together with a reward model (RM). XPO augments the DPO objective with an exploration bonus allowing the method to explore outside the support of the initial model and human feedback data. + +The abstract from the paper is the following: + +> Reinforcement learning from human feedback (RLHF) has emerged as a central tool for language model alignment. We consider online exploration in RLHF, which exploits interactive access to human or AI feedback by deliberately encouraging the model to produce diverse, maximally informative responses. By allowing RLHF to confidently stray from the pre-trained model, online exploration offers the possibility of novel, potentially super-human capabilities, but its full potential as a paradigm for language model training has yet to be realized, owing to computational and statistical bottlenecks in directly adapting existing reinforcement learning techniques. We propose a new algorithm for online exploration in RLHF, Exploratory Preference Optimization (XPO), which is simple and practical -- a one-line change to (online) Direct Preference Optimization (DPO; Rafailov et al., 2023) -- yet enjoys the strongest known provable guarantees and promising empirical performance. XPO augments the DPO objective with a novel and principled exploration bonus, empowering the algorithm to explore outside the support of the initial model and human feedback data. In theory, we show that XPO is provably sample-efficient and converges to a near-optimal language model policy under natural exploration conditions, irrespective of whether the initial model has good coverage. Our analysis, which builds on the observation that DPO implicitly performs a form of Q*-approximation (or, Bellman error minimization), combines previously disparate techniques from language modeling and theoretical reinforcement learning in a serendipitous fashion through the perspective of KL-regularized Markov decision processes. Empirically, we find that XPO is more sample-efficient than non-exploratory DPO variants in a preliminary evaluation. + +This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif), [Quentin Gallouédec](https://huggingface.co/qgallouedec) and [Lewis Tunstall](https://huggingface.co/lewtun). + +## Quick start + +This example demonstrates how to train a model using the XPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and [`PairRMJudge`] as a judge. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here: + + +Below is the script to train the model: + +```python +# train_xpo.py +from datasets import load_dataset +from trl import PairRMJudge, XPOConfig, XPOTrainer +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +judge = PairRMJudge() +train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train") + +training_args = XPOConfig(output_dir="Qwen2-0.5B-XPO", logging_steps=10) +trainer = XPOTrainer( + model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset +) +trainer.train() +``` + +Execute the script using the following command: + +```bash +accelerate launch train_xpo.py +``` + +Distributed across 8 GPUs, the training takes approximately 1 hour. + +To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-XPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models). + +
$ transformers chat trl-lib/Qwen2-0.5B-XPO
+<quentin_gallouedec>:
+What is the best programming language?
+
+<trl-lib/Qwen2-0.5B-XPO>:
+The best programming language depends on individual preferences and familiarity with coding concepts. Some popular languages include Python, Java, C++, and JavaScript. 
+
+ +## Expected dataset type + +XPO requires a [prompt-only dataset](dataset_formats#prompt-only). The [`XPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. + +## Usage tips + +### Use a reward model + +Instead of a judge, you can chose to use a reward model -- see [Reward Bench](https://huggingface.co/spaces/allenai/reward-bench) for a leaderboard of public models you can use. Below is a code example showing how to replace a judge with the [trl-lib/Qwen2-0.5B-Reward](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward) model: + +```diff +- from trl import PairRMJudge ++ from transformers import AutoModelForSequenceClassification + +- judge = PairRMJudge() ++ reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1) + + trainer = XPOTrainer( + ... +- judge=judge, ++ reward_model=reward_model, + ) +``` + + + +Make sure that the SFT model and reward model use the _same_ chat template and the same tokenizer. Otherwise, you may find the model completions are scored incorrectly during training. + + + +### Encourage EOS token generation + +When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`XPOConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`XPOConfig`]: + +```python +training_args = XPOConfig(..., max_new_tokens=128, missing_eos_penalty=1.0) +``` + +### Logging Completions + +To better understand your model’s behavior during training, you can log sample completions periodically using the [`LogCompletionsCallback`]. + +```python +trainer = XPOTrainer(..., eval_dataset=eval_dataset) +completions_callback = LogCompletionsCallback(trainer, num_prompts=8) +trainer.add_callback(completions_callback) +``` + +This callback logs the model's generated completions directly to Weights & Biases. + +![Logged Completions](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/wandb_completions.png) + +## Example script + +We provide an example script to train a model using the XPO method. The script is available in [`examples/scripts/xpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/xpo.py) + +To test the XPO script with the [Qwen2.5 0.5B model](https://huggingface.co/trl-lib/Qwen/Qwen2.5-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback), run the following command: + +```bash +python examples/scripts/xpo.py \ + --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ + --judge pair_rm \ + --dataset_name trl-lib/ultrafeedback-prompt \ + --learning_rate 5.0e-7 \ + --logging_steps 25 \ + --output_dir Qwen2.5-0.5B-XPO-PairRM \ + --warmup_ratio 0.1 \ + --push_to_hub +``` + +## Logged metrics + +The logged metrics are as follows: + +* `loss/xpo`: The mean xpo part of the full loss. +* `loss/dpo`: The mean dpo part of the full loss. +* `objective/kl`: The mean KL divergence between the model and reference data. +* `objective/entropy`: The mean entropy of the model and reference data. +* `objective/model_scores`: The mean scores (according to the reward model) of the model completions. +* `objective/ref_scores`: The mean scores (according to the reward model) of the reference completions. +* `objective/scores_margin`: The mean score margin (according to the external reward model) between the chosen and rejected completions. +* `rewards/chosen`: The mean reward (according to XPO's DPO implicit reward model) of the chosen completions. +* `rewards/rejected`: The mean reward (according to XPO's DPO implicit reward model) of the rejected completions. +* `rewards/accuracies`: The accuracies of the XPO's implicit reward model. +* `rewards/margins`: The mean reward margin (according to online DPO's implicit reward model) between the chosen and rejected completions. +* `logps/chosen`: The mean log probabilities of the chosen completions. +* `logps/rejected`: The mean log probabilities of the rejected completions. +* `val/model_contain_eos_token`: The amount of times the model's output contains the eos token. +* `val/ref_contain_eos_token`: The amount of times the reference's output contains the eos token. +* `alpha`: The weight of the XPO loss term. Typically fixed, but can be made dynamic by passing a list to [`XPOConfig`]. +* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`XPOConfig`]. + + +## XPOTrainer + +[[autodoc]] XPOTrainer + +## XPOConfig + +[[autodoc]] XPOConfig diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000000000000000000000000000000000000..37999e41abc02461a09ed7e29e39cc0bec15e488 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,3 @@ +# Examples + +Please check out https://huggingface.co/docs/trl/example_overview for documentation on our examples. \ No newline at end of file diff --git a/examples/accelerate_configs/deepspeed_zero1.yaml b/examples/accelerate_configs/deepspeed_zero1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d5b5f782fb30f9fcbcc8fc58262f09eaf2e10368 --- /dev/null +++ b/examples/accelerate_configs/deepspeed_zero1.yaml @@ -0,0 +1,20 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 1 + zero3_init_flag: false + zero_stage: 1 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/accelerate_configs/deepspeed_zero2.yaml b/examples/accelerate_configs/deepspeed_zero2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..239b14ac3a9ae8de73122d1154bf0d71903dc15f --- /dev/null +++ b/examples/accelerate_configs/deepspeed_zero2.yaml @@ -0,0 +1,21 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/accelerate_configs/deepspeed_zero3.yaml b/examples/accelerate_configs/deepspeed_zero3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b5a1201f8a2ee8706b63f0f80c664a1fc61a7d9d --- /dev/null +++ b/examples/accelerate_configs/deepspeed_zero3.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/accelerate_configs/fsdp1.yaml b/examples/accelerate_configs/fsdp1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c01b0b567bc93bf87ec136ea975b3793d273a45c --- /dev/null +++ b/examples/accelerate_configs/fsdp1.yaml @@ -0,0 +1,28 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +enable_cpu_affinity: false +fsdp_config: + fsdp_activation_checkpointing: false + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_backward_prefetch: BACKWARD_PRE + fsdp_cpu_ram_efficient_loading: true + fsdp_forward_prefetch: true + fsdp_offload_params: false + fsdp_reshard_after_forward: FULL_SHARD + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sync_module_states: true + fsdp_use_orig_params: true + fsdp_version: 1 +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/accelerate_configs/fsdp2.yaml b/examples/accelerate_configs/fsdp2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..af498f3eced9c2434b80113f2f22d40395e0ab8a --- /dev/null +++ b/examples/accelerate_configs/fsdp2.yaml @@ -0,0 +1,25 @@ +# Requires accelerate 1.7.0 or higher +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +enable_cpu_affinity: false +fsdp_config: + fsdp_activation_checkpointing: false + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_cpu_ram_efficient_loading: true + fsdp_offload_params: false + fsdp_reshard_after_forward: true + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_version: 2 +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/accelerate_configs/multi_gpu.yaml b/examples/accelerate_configs/multi_gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..15dad9be3ba44f7c934e1ecab98a93cb83cbc79a --- /dev/null +++ b/examples/accelerate_configs/multi_gpu.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +gpu_ids: all +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/accelerate_configs/single_gpu.yaml b/examples/accelerate_configs/single_gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ebd00a067118e56f3d63ab0f24827cfea21b24b9 --- /dev/null +++ b/examples/accelerate_configs/single_gpu.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: "NO" +downcast_bf16: 'no' +gpu_ids: all +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/cli_configs/example_config.yaml b/examples/cli_configs/example_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bb44cec69e219e3435057b8437a1bb6198a339b0 --- /dev/null +++ b/examples/cli_configs/example_config.yaml @@ -0,0 +1,18 @@ +# This is an example configuration file of TRL CLI, you can use it for +# SFT like that: `trl sft --config config.yaml --output_dir test-sft` +# The YAML file supports environment variables by adding an `env` field +# as below + +# env: +# CUDA_VISIBLE_DEVICES: 0 + +model_name_or_path: + Qwen/Qwen2.5-0.5B +dataset_name: + stanfordnlp/imdb +report_to: + none +learning_rate: + 0.0001 +lr_scheduler_type: + cosine diff --git a/examples/datasets/hh-rlhf-helpful-base.py b/examples/datasets/hh-rlhf-helpful-base.py new file mode 100644 index 0000000000000000000000000000000000000000..fdb16c3bb534e647a93b57cafe4ea44f4dcdf488 --- /dev/null +++ b/examples/datasets/hh-rlhf-helpful-base.py @@ -0,0 +1,133 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from dataclasses import dataclass, field +from typing import Optional + +from datasets import load_dataset +from huggingface_hub import ModelCard +from transformers import HfArgumentParser + + +@dataclass +class ScriptArguments: + r""" + Arguments for the script. + + Args: + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether to push the dataset to the Hugging Face Hub. + repo_id (`str`, *optional*, defaults to `"trl-lib/hh-rlhf-helpful-base"`): + Hugging Face repository ID to push the dataset to. + dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): + Number of workers to use for dataset processing. + """ + + push_to_hub: bool = field( + default=False, + metadata={"help": "Whether to push the dataset to the Hugging Face Hub."}, + ) + repo_id: str = field( + default="trl-lib/hh-rlhf-helpful-base", metadata={"help": "Hugging Face repository ID to push the dataset to."} + ) + dataset_num_proc: Optional[int] = field( + default=None, metadata={"help": "Number of workers to use for dataset processing."} + ) + + +def common_start(str1: str, str2: str) -> str: + # Zip the two strings and iterate over them together + common_chars = [] + for c1, c2 in zip(str1, str2): + if c1 == c2: + common_chars.append(c1) + else: + break + # Join the common characters and return as a string + return "".join(common_chars) + + +def extract_dialogue(example: str) -> list[dict[str, str]]: + # Extract the prompt, which corresponds to the common start of the chosen and rejected dialogues + prompt_text = common_start(example["chosen"], example["rejected"]) + + # The chosen and rejected may share a common start, so we need to remove the common part + if not prompt_text.endswith("\n\nAssistant: "): + prompt_text = prompt_text[: prompt_text.rfind("\n\nAssistant: ")] + "\n\nAssistant: " + + # Extract the chosen and rejected lines + chosen_line = example["chosen"][len(prompt_text) :] + rejected_line = example["rejected"][len(prompt_text) :] + + # Remove the generation prompt ("\n\nAssistant: ") from the prompt + prompt_text = prompt_text[: -len("\n\nAssistant: ")] + + # Split the string at every occurrence of "Human: " or "Assistant: " + prompt_lines = re.split(r"(\n\nAssistant: |\n\nHuman: )", prompt_text) + + # Remove the first element as it's empty + prompt_lines = prompt_lines[1:] + + prompt = [] + for idx in range(0, len(prompt_lines), 2): + role = "user" if prompt_lines[idx] == "\n\nHuman: " else "assistant" + content = prompt_lines[idx + 1] + prompt.append({"role": role, "content": content}) + + # Remove the prompt from the chosen and rejected dialogues + chosen = [{"role": "assistant", "content": chosen_line}] + rejected = [{"role": "assistant", "content": rejected_line}] + + return {"prompt": prompt, "chosen": chosen, "rejected": rejected} + + +model_card = ModelCard(""" +--- +tags: [trl] +--- + +# HH-RLHF-Helpful-Base Dataset + +## Summary + +The HH-RLHF-Helpful-Base dataset is a processed version of [Anthropic's HH-RLHF](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset, specifically curated to train models using the [TRL library](https://github.com/huggingface/trl) for preference learning and alignment tasks. It contains pairs of text samples, each labeled as either "chosen" or "rejected," based on human preferences regarding the helpfulness of the responses. This dataset enables models to learn human preferences in generating helpful responses, enhancing their ability to assist users effectively. + +## Data Structure + +- **Format**: [Conversational](https://huggingface.co/docs/trl/main/dataset_formats#conversational) +- **Type**: [Preference](https://huggingface.co/docs/trl/main/dataset_formats#preference) + +Columns: +- `"prompt"`: The user query. +- `"chosen"`: A response deemed helpful by human evaluators. +- `"rejected"`: A response considered less helpful or unhelpful. + +This structure allows models to learn to prefer the _chosen_ response over the _rejected_ one, thereby aligning with human preferences in helpfulness. + +## Generation script + +The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/hh-rlhf-helpful-base.py). +""") + +if __name__ == "__main__": + parser = HfArgumentParser(ScriptArguments) + script_args = parser.parse_args_into_dataclasses()[0] + + dataset = load_dataset("Anthropic/hh-rlhf", data_dir="helpful-base") + dataset = dataset.map(extract_dialogue, num_proc=script_args.dataset_num_proc) + + if script_args.push_to_hub: + dataset.push_to_hub(script_args.repo_id) + model_card.push_to_hub(script_args.repo_id, repo_type="dataset") diff --git a/examples/datasets/lm-human-preferences-descriptiveness.py b/examples/datasets/lm-human-preferences-descriptiveness.py new file mode 100644 index 0000000000000000000000000000000000000000..e7dbf83a17b8797e2d8e7d93fc3c1ecc4d66233f --- /dev/null +++ b/examples/datasets/lm-human-preferences-descriptiveness.py @@ -0,0 +1,120 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional + +from datasets import load_dataset +from huggingface_hub import ModelCard +from transformers import AutoTokenizer, HfArgumentParser + + +@dataclass +class ScriptArguments: + r""" + Arguments for the script. + + Args: + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether to push the dataset to the Hugging Face Hub. + repo_id (`str`, *optional*, defaults to `"trl-lib/lm-human-preferences-descriptiveness"`): + Hugging Face repository ID to push the dataset to. + dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): + Number of workers to use for dataset processing. + """ + + push_to_hub: bool = field( + default=False, + metadata={"help": "Whether to push the dataset to the Hugging Face Hub."}, + ) + repo_id: str = field( + default="trl-lib/lm-human-preferences-descriptiveness", + metadata={"help": "Hugging Face repository ID to push the dataset to."}, + ) + dataset_num_proc: Optional[int] = field( + default=None, + metadata={"help": "Number of workers to use for dataset processing."}, + ) + + +# Edge cases handling: remove the cases where all samples are the same +def samples_not_all_same(example): + return not all(example["sample0"] == example[f"sample{j}"] for j in range(1, 4)) + + +def to_prompt_completion(example, tokenizer): + prompt = tokenizer.decode(example["query"]).strip() + best_idx = example["best"] + chosen = tokenizer.decode(example[f"sample{best_idx}"]) + for rejected_idx in range(4): # take the first rejected sample that is different from the chosen one + rejected = tokenizer.decode(example[f"sample{rejected_idx}"]) + if chosen != rejected: + break + assert chosen != rejected + return {"prompt": prompt, "chosen": chosen, "rejected": rejected} + + +model_card = ModelCard(""" +--- +tags: [trl] +--- + +# LM-Human-Preferences-Descriptiveness Dataset + +## Summary + +The LM-Human-Preferences-Descriptiveness dataset is a processed subset of [OpenAI's LM-Human-Preferences](https://github.com/openai/lm-human-preferences), focusing specifically on enhancing the descriptiveness of generated text. It contains pairs of text samples, each labeled as either "chosen" or "rejected," based on human preferences regarding the level of detail and vividness in the descriptions. This dataset enables models to learn human preferences in descriptive language, improving their ability to generate rich and engaging narratives. + +## Data Structure + +- **Format**: [Standard](https://huggingface.co/docs/trl/main/dataset_formats#standard) +- **Type**: [Preference](https://huggingface.co/docs/trl/main/dataset_formats#preference) + +Columns: +- `"prompt"`: The text sample. +- `"chosen"`: A version of the text with enhanced descriptiveness. +- `"rejected"`: A version of the text with less descriptiveness. + +This structure allows models to learn to prefer the _chosen_ response over the _rejected_ one, thereby aligning with human preferences in descriptive language. + +## Generation script + +The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/lm-human-preferences-descriptiveness.py). +""") + +if __name__ == "__main__": + parser = HfArgumentParser(ScriptArguments) + script_args = parser.parse_args_into_dataclasses()[0] + + dataset = load_dataset( + "json", + data_files="https://openaipublic.blob.core.windows.net/lm-human-preferences/labels/descriptiveness/offline_5k.json", + split="train", + ) + + dataset = dataset.filter(samples_not_all_same, num_proc=script_args.dataset_num_proc) + + dataset = dataset.map( + to_prompt_completion, + num_proc=script_args.dataset_num_proc, + remove_columns=["query", "sample0", "sample1", "sample2", "sample3", "best"], + fn_kwargs={"tokenizer": AutoTokenizer.from_pretrained("gpt2")}, + ) + + # train_size taken from https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/launch.py#L79) + dataset = dataset.train_test_split(train_size=4992) + + if script_args.push_to_hub: + dataset.push_to_hub(script_args.repo_id) + model_card.push_to_hub(script_args.repo_id, repo_type="dataset") diff --git a/examples/datasets/lm-human-preferences-sentiment.py b/examples/datasets/lm-human-preferences-sentiment.py new file mode 100644 index 0000000000000000000000000000000000000000..fddf483c46ffc615cef356d93959a98d0dada692 --- /dev/null +++ b/examples/datasets/lm-human-preferences-sentiment.py @@ -0,0 +1,113 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional + +from datasets import load_dataset +from huggingface_hub import ModelCard +from transformers import AutoTokenizer, HfArgumentParser + + +@dataclass +class ScriptArguments: + r""" + Arguments for the script. + + Args: + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether to push the dataset to the Hugging Face Hub. + repo_id (`str`, *optional*, defaults to `"trl-lib/lm-human-preferences-sentiment"`): + Hugging Face repository ID to push the dataset to. + dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): + Number of workers to use for dataset processing. + """ + + push_to_hub: bool = field( + default=False, + metadata={"help": "Whether to push the dataset to the Hugging Face Hub."}, + ) + repo_id: str = field( + default="trl-lib/lm-human-preferences-sentiment", + metadata={"help": "Hugging Face repository ID to push the dataset to."}, + ) + dataset_num_proc: Optional[int] = field( + default=None, + metadata={"help": "Number of workers to use for dataset processing."}, + ) + + +def to_prompt_completion(example, tokenizer): + prompt = tokenizer.decode(example["query"]).strip() + best_idx = example["best"] + chosen = tokenizer.decode(example[f"sample{best_idx}"]) + for rejected_idx in range(4): # take the first rejected sample that is different from the chosen one + rejected = tokenizer.decode(example[f"sample{rejected_idx}"]) + if chosen != rejected: + break + assert chosen != rejected + return {"prompt": prompt, "chosen": chosen, "rejected": rejected} + + +model_card = ModelCard(""" +--- +tags: [trl] +--- + +# LM-Human-Preferences-Sentiment Dataset + +## Summary + +The LM-Human-Preferences-Sentiment dataset is a processed subset of [OpenAI's LM-Human-Preferences](https://github.com/openai/lm-human-preferences), focusing specifically on sentiment analysis tasks. It contains pairs of text samples, each labeled as either "chosen" or "rejected," based on human preferences regarding the sentiment conveyed in the text. This dataset enables models to learn human preferences in sentiment expression, enhancing their ability to generate and evaluate text with desired emotional tones. + +## Data Structure + +- **Format**: [Standard](https://huggingface.co/docs/trl/main/dataset_formats#standard) +- **Type**: [Preference](https://huggingface.co/docs/trl/main/dataset_formats#preference) + +Columns: +- `"prompt"`: The text sample. +- `"chosen"`: A version of the text that conveys the desired sentiment. +- `"rejected"`: A version of the text that does not convey the desired sentiment. + +This structure allows models to learn to prefer the _chosen_ response over the _rejected_ one, thereby aligning with human preferences in sentiment expression. + +## Generation script + +The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/lm-human-preferences-sentiment.py). +""") + +if __name__ == "__main__": + parser = HfArgumentParser(ScriptArguments) + script_args = parser.parse_args_into_dataclasses()[0] + + dataset = load_dataset( + "json", + data_files="https://openaipublic.blob.core.windows.net/lm-human-preferences/labels/sentiment/offline_5k.json", + split="train", + ) + + dataset = dataset.map( + to_prompt_completion, + num_proc=script_args.dataset_num_proc, + remove_columns=["query", "sample0", "sample1", "sample2", "sample3", "best"], + fn_kwargs={"tokenizer": AutoTokenizer.from_pretrained("gpt2")}, + ) + + # train_size taken from https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/launch.py#L70) + dataset = dataset.train_test_split(train_size=4992) + + if script_args.push_to_hub: + dataset.push_to_hub(script_args.repo_id) + model_card.push_to_hub(script_args.repo_id, repo_type="dataset") diff --git a/examples/datasets/math_shepherd.py b/examples/datasets/math_shepherd.py new file mode 100644 index 0000000000000000000000000000000000000000..6eec699ce5a6a1a1cb818815e21cc50a00b49034 --- /dev/null +++ b/examples/datasets/math_shepherd.py @@ -0,0 +1,170 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from dataclasses import dataclass, field +from itertools import chain +from typing import Optional + +from datasets import load_dataset +from huggingface_hub import ModelCard +from transformers import HfArgumentParser + + +@dataclass +class ScriptArguments: + r""" + Arguments for the script. + + Args: + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether to push the dataset to the Hugging Face Hub. + repo_id (`str`, *optional*, defaults to `"trl-lib/math_shepherd"`): + Hugging Face repository ID to push the dataset to. + dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): + Number of workers to use for dataset processing. + """ + + push_to_hub: bool = field( + default=False, + metadata={"help": "Whether to push the dataset to the Hugging Face Hub."}, + ) + repo_id: str = field( + default="trl-lib/math_shepherd", + metadata={"help": "Hugging Face repository ID to push the dataset to."}, + ) + dataset_num_proc: Optional[int] = field( + default=None, + metadata={"help": "Number of workers to use for dataset processing."}, + ) + + +def process_example(example): + # Replace "ки" with "ⶻ" so that the size of the "input" matches the size of the "label" + inputs = example["input"].replace("ки", "ⶻ") + + # Find the indices of the "ⶻ" characters (that should match with the indexes of the "+" or "-" in the label) + indexes = [m.start() for m in re.finditer("ⶻ", inputs)] + + # Sanity that all indexes are either "+" or "-" + assert all(example["label"][idx] in ["+", "-"] for idx in indexes) + + # Get the labels + labels = [example["label"][idx] == "+" for idx in indexes] + + # Split the inputs into steps (caution, the first step is missing here, it is the prompt) + steps = [inputs[i:j] for i, j in zip(chain([0], indexes), chain(indexes, [None]))] + + # Remove the last step (single ⶻ) + steps = steps[:-1] + + # Get the prompt (first part) and completions (rest) + prompt = steps[0] + completions = steps[1:] + + # Remove the heading "ⶻ" and the final whitespace from the completions + assert all(completion.startswith("ⶻ") for completion in completions) + completions = [completion[1:].strip() for completion in completions] + + # At this point, we need to retrieve the first step from the prompt. + # First, we handle particular cases (annotation error) where we have a first label before the end of the prompt. + if prompt.startswith( + ( + "Mr. Rocky", + "Parker", + "What is the smallest positive", + " The Myth", + "Let $\\mathbf{a}$", + "Find the arithmetic", + "Determine an ordered pair", + "Determine the ordered pair", + "At the Quill and Scroll stationery", + "Round to the nearest", + r"Calculate $\sqrt{10p}", + r"Simplify $\sqrt{28x}", + ) + ): + # Some spotted datasets errors where there is an annotation in the prompt: we remove it + labels = labels[1:] + + # Then we handle the general case: we get the first step from the prompt by looking for "Step 1:" or "step 1:" or + # (less common) "?". + elif "Step 1:" in prompt: + prompt, first_step = prompt.split("Step 1:") + first_step = "Step 1:" + first_step + completions = [first_step.strip()] + completions + elif "step 1:" in prompt: + prompt, first_step = prompt.split("step 1:") + first_step = "step 1:" + first_step + completions = [first_step.strip()] + completions + elif "?" in prompt: + prompt, first_step = prompt.split("?") + prompt = prompt + "?" + completions = [first_step.strip()] + completions + else: + raise ValueError(f"Prompt can't be processed: {prompt}") + + # Strip the prompt + prompt = prompt.strip() + + # Sanity check that the length of the completions is the same as the length of the labels + assert len(completions) == len(labels) + + return {"prompt": prompt, "completions": completions, "labels": labels} + + +model_card = ModelCard(""" +--- +tags: [trl] +--- + +# Math-Shepherd Dataset + +## Summary + +The Math-Shepherd dataset is a processed version of [Math-Shepherd dataset](peiyi9979/Math-Shepherd), designed to train models using the [TRL library](https://github.com/huggingface/trl) for stepwise supervision tasks. It provides step-by-step solutions to mathematical problems, enabling models to learn and verify each step of a solution, thereby enhancing their reasoning capabilities. + +## Data Structure + +- **Format**: [Standard](https://huggingface.co/docs/trl/main/dataset_formats#standard) +- **Type**: [Stepwise supervision](https://huggingface.co/docs/trl/main/dataset_formats#stepwise-supervision) + +Columns: +- `"prompt"`: The problem statement. +- `"completions"`: A list of reasoning steps generated to solve the problem. +- `"labels"`: A list of booleans or floats indicating the correctness of each corresponding reasoning step. + +This structure allows models to learn the correctness of each step in a solution, facilitating improved reasoning and problem-solving abilities. + +## Generation script + +The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/math_shepherd.py). +""") + +if __name__ == "__main__": + parser = HfArgumentParser(ScriptArguments) + script_args = parser.parse_args_into_dataclasses()[0] + + dataset = load_dataset("peiyi9979/Math-Shepherd", split="train") + + dataset = dataset.map( + process_example, + remove_columns=["input", "label", "task"], + num_proc=script_args.dataset_num_proc, + ) + dataset = dataset.train_test_split(test_size=0.05, seed=42) + + if script_args.push_to_hub: + dataset.push_to_hub(script_args.repo_id) + model_card.push_to_hub(script_args.repo_id, repo_type="dataset") diff --git a/examples/datasets/prm800k.py b/examples/datasets/prm800k.py new file mode 100644 index 0000000000000000000000000000000000000000..2947755a080f730d8c7ad381d59b5ab6f8a697f2 --- /dev/null +++ b/examples/datasets/prm800k.py @@ -0,0 +1,157 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional + +from datasets import load_dataset +from huggingface_hub import ModelCard +from transformers import HfArgumentParser + + +@dataclass +class ScriptArguments: + r""" + Arguments for the script. + + Args: + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether to push the dataset to the Hugging Face Hub. + repo_id (`str`, *optional*, defaults to `"trl-lib/prm800k"`): + Hugging Face repository ID to push the dataset to. + dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): + Number of workers to use for dataset processing. + """ + + push_to_hub: bool = field( + default=False, + metadata={"help": "Whether to push the dataset to the Hugging Face Hub."}, + ) + repo_id: str = field( + default="trl-lib/prm800k", + metadata={"help": "Hugging Face repository ID to push the dataset to."}, + ) + dataset_num_proc: Optional[int] = field( + default=None, + metadata={"help": "Number of workers to use for dataset processing."}, + ) + + +def process_example(example): + outputs = [] + prompt = example["question"]["problem"] + + # Iterate through each step + previous_completions = [] + previous_labels = [] + for step in example["label"]["steps"]: + if step["completions"] is None and step["human_completion"] is None and step["chosen_completion"] is None: + # happens sometimes + break + # Loop through completions + for completion_idx, completion in enumerate(step["completions"]): + # For every completion that are not chosen, we are in a terminal state, so we can add it to the list of outputs. + if completion_idx != step["chosen_completion"]: + content = completion["text"] + completions = previous_completions[:] + [content] + label = completion["rating"] == 1 + labels = previous_labels[:] + [label] + outputs.append({"prompt": prompt, "completions": completions, "labels": labels}) + + # Now, exapand the previous completions and labels + if step["chosen_completion"] is not None: + chosen_completion = step["completions"][step["chosen_completion"]] + label = chosen_completion["rating"] == 1 + elif step["human_completion"] is not None: + chosen_completion = step["human_completion"] + label = True + else: + break + content = chosen_completion["text"] + previous_completions.append(content) + previous_labels.append(label) + + # Last step: we are in a terminal state, so we can add it to the list of outputs + outputs.append({"prompt": prompt, "completions": previous_completions, "labels": previous_labels}) + return outputs + + +def process_batch(examples): + outputs = [] + batch_size = len(examples["label"]) + for idx in range(batch_size): + example = {k: v[idx] for k, v in examples.items()} + outputs.extend(process_example(example)) + # list of dict to dict of list + outputs = {k: [v[k] for v in outputs] for k in outputs[0]} + return outputs + + +model_card = ModelCard(""" +--- +tags: [trl] +--- + +# PRM800K Dataset + +## Summary + +The PRM800K dataset is a processed version of [OpenAI's PRM800K](https://github.com/openai/prm800k), designed to train models using the [TRL library](https://github.com/huggingface/trl) for stepwise supervision tasks. It contains 800,000 step-level correctness labels for model-generated solutions to problems from the MATH dataset. This dataset enables models to learn and verify each step of a solution, enhancing their reasoning capabilities. + +## Data Structure + +- **Format**: [Standard](https://huggingface.co/docs/trl/main/dataset_formats#standard) +- **Type**: [Stepwise supervision](https://huggingface.co/docs/trl/main/dataset_formats#stepwise-supervision) + +Columns: +- `"prompt"`: The problem statement. +- `"completions"`: A list of reasoning steps generated to solve the problem. +- `"labels"`: A list of booleans or floats indicating the correctness of each corresponding reasoning step. + +This structure allows models to learn the correctness of each step in a solution, facilitating improved reasoning and problem-solving abilities. + +## Generation script + +The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/prm800k.py). +""") + +if __name__ == "__main__": + parser = HfArgumentParser(ScriptArguments) + script_args = parser.parse_args_into_dataclasses()[0] + + data_files = { + "train": "https://github.com/openai/prm800k/raw/refs/heads/main/prm800k/data/phase1_train.jsonl", + "test": "https://github.com/openai/prm800k/raw/refs/heads/main/prm800k/data/phase1_test.jsonl", + } + dataset = load_dataset("json", data_files=data_files) + + dataset = dataset.map( + process_batch, + batched=True, + batch_size=10, + remove_columns=[ + "labeler", + "timestamp", + "generation", + "is_quality_control_question", + "is_initial_screening_question", + "question", + "label", + ], + num_proc=script_args.dataset_num_proc, + ) + + if script_args.push_to_hub: + dataset.push_to_hub(script_args.repo_id) + model_card.push_to_hub(script_args.repo_id, repo_type="dataset") diff --git a/examples/datasets/rlaif-v.py b/examples/datasets/rlaif-v.py new file mode 100644 index 0000000000000000000000000000000000000000..093ae98a8406c58802a91467e4580c8dd53d9c81 --- /dev/null +++ b/examples/datasets/rlaif-v.py @@ -0,0 +1,113 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional + +from datasets import features, load_dataset +from huggingface_hub import ModelCard +from transformers import HfArgumentParser + + +@dataclass +class ScriptArguments: + r""" + Arguments for the script. + + Args: + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether to push the dataset to the Hugging Face Hub. + repo_id (`str`, *optional*, defaults to `"trl-lib/rlaif-v"`): + Hugging Face repository ID to push the dataset to. + dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): + Number of workers to use for dataset processing. + """ + + push_to_hub: bool = field( + default=False, + metadata={"help": "Whether to push the dataset to the Hugging Face Hub."}, + ) + repo_id: str = field( + default="trl-lib/rlaif-v", + metadata={"help": "Hugging Face repository ID to push the dataset to."}, + ) + dataset_num_proc: Optional[int] = field( + default=None, + metadata={"help": "Number of workers to use for dataset processing."}, + ) + + +def to_conversational(example): + """ + Convert prompt from "xxx" to [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "xxx"}]}] + and chosen and rejected from "xxx" to [{"role": "assistant", "content": [{"type": "text", "text": "xxx"}]}]. + Images are wrapped into a list. + """ + prompt = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": example["question"]}]}] + chosen = [{"role": "assistant", "content": [{"type": "text", "text": example["chosen"]}]}] + rejected = [{"role": "assistant", "content": [{"type": "text", "text": example["rejected"]}]}] + return {"prompt": prompt, "images": [example["image"]], "chosen": chosen, "rejected": rejected} + + +model_card = ModelCard(""" +--- +tags: [trl] +--- + +# RLAIF-V Dataset + +## Summary + +The RLAIF-V dataset is a processed version of the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset#dataset-card-for-rlaif-v-dataset), specifically curated to train vision-language models using the [TRL library](https://github.com/huggingface/trl) for preference learning tasks. It contains 83,132 high-quality comparison pairs, each comprising an image and two textual descriptions: one preferred and one rejected. This dataset enables models to learn human preferences in visual contexts, enhancing their ability to generate and evaluate image captions. + +## Data Structure + +- **Format**: [Conversational](https://huggingface.co/docs/trl/main/dataset_formats#conversational) +- **Type**: [Preference](https://huggingface.co/docs/trl/main/dataset_formats#preference) + +Columns: +- `"prompt"`: The task related to the image. +- `"images"`: The image. +- `"chosen"`: The preferred answer. +- `"rejected"`: An alternative answer that was not preferred. + +This structure allows models to learn to prefer the _chosen_ response over the _rejected_ one, thereby aligning with human preferences in visual tasks. + +## Generation script + +The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/rlaif-v.py). +""") + +if __name__ == "__main__": + parser = HfArgumentParser(ScriptArguments) + script_args = parser.parse_args_into_dataclasses()[0] + + dataset = load_dataset("openbmb/RLAIF-V-Dataset", split="train") + dataset = dataset.map( + to_conversational, + num_proc=script_args.dataset_num_proc, + remove_columns=dataset.column_names, + writer_batch_size=128, + ) + + # Cast the images to Sequence[Image] to avoid bytes format + f = dataset.features + f["images"] = features.Sequence(features.Image(decode=True)) + dataset = dataset.cast(f) + + dataset = dataset.train_test_split(test_size=0.01, writer_batch_size=128) + + if script_args.push_to_hub: + dataset.push_to_hub(script_args.repo_id) + model_card.push_to_hub(script_args.repo_id, repo_type="dataset") diff --git a/examples/datasets/tldr.py b/examples/datasets/tldr.py new file mode 100644 index 0000000000000000000000000000000000000000..2c905f849d18bec81984fbfcbc21531a3ce5336c --- /dev/null +++ b/examples/datasets/tldr.py @@ -0,0 +1,105 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional + +from datasets import load_dataset +from huggingface_hub import ModelCard +from transformers import HfArgumentParser + + +@dataclass +class ScriptArguments: + r""" + Arguments for the script. + + Args: + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether to push the dataset to the Hugging Face Hub. + repo_id (`str`, *optional*, defaults to `"trl-lib/tldr"`): + Hugging Face repository ID to push the dataset to. + dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): + Number of workers to use for dataset processing. + """ + + push_to_hub: bool = field( + default=False, + metadata={"help": "Whether to push the dataset to the Hugging Face Hub."}, + ) + repo_id: str = field( + default="trl-lib/tldr", + metadata={"help": "Hugging Face repository ID to push the dataset to."}, + ) + dataset_num_proc: Optional[int] = field( + default=None, + metadata={"help": "Number of workers to use for dataset processing."}, + ) + + +def to_prompt_completion(example): + tldr_format_str = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" + prompt = tldr_format_str.format(subreddit=example["subreddit"], title=example["title"], post=example["post"]) + completion = " " + example["summary"] # Add a space to separate the prompt from the completion + return {"prompt": prompt, "completion": completion} + + +model_card = ModelCard(""" +--- +tags: [trl] +--- + +# TL;DR Dataset + +## Summary + +The TL;DR dataset is a processed version of Reddit posts, specifically curated to train models using the [TRL library](https://github.com/huggingface/trl) for summarization tasks. It leverages the common practice on Reddit where users append "TL;DR" (Too Long; Didn't Read) summaries to lengthy posts, providing a rich source of paired text data for training summarization models. + +## Data Structure + +- **Format**: [Standard](https://huggingface.co/docs/trl/main/dataset_formats#standard) +- **Type**: [Prompt-completion](https://huggingface.co/docs/trl/main/dataset_formats#prompt-completion) + +Columns: +- `"prompt"`: The unabridged Reddit post. +- `"completion"`: The concise "TL;DR" summary appended by the author. + +This structure enables models to learn the relationship between detailed content and its abbreviated form, enhancing their summarization capabilities. + +## Generation script + +The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/tldr.py). +""") + +if __name__ == "__main__": + parser = HfArgumentParser(ScriptArguments) + script_args = parser.parse_args_into_dataclasses()[0] + + # Filtered reddit TL;DR dataset from https://github.com/openai/summarize-from-feedback?tab=readme-ov-file#reddit-tldr-dataset + data_files = { + "train": "https://openaipublic.blob.core.windows.net/summarize-from-feedback/datasets/tldr_3_filtered/train.jsonl", + "validation": "https://openaipublic.blob.core.windows.net/summarize-from-feedback/datasets/tldr_3_filtered/valid.jsonl", + "test": "https://openaipublic.blob.core.windows.net/summarize-from-feedback/datasets/tldr_3_filtered/test.jsonl", + } + dataset = load_dataset("json", data_files=data_files) + + dataset = dataset.map( + to_prompt_completion, + num_proc=script_args.dataset_num_proc, + remove_columns=["id", "subreddit", "title", "post", "summary"], + ) + + if script_args.push_to_hub: + dataset.push_to_hub(script_args.repo_id) + model_card.push_to_hub(script_args.repo_id, repo_type="dataset") diff --git a/examples/datasets/tldr_preference.py b/examples/datasets/tldr_preference.py new file mode 100644 index 0000000000000000000000000000000000000000..1ce9db360dec6d720e51c69614e91b23d6692c7e --- /dev/null +++ b/examples/datasets/tldr_preference.py @@ -0,0 +1,111 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional + +from datasets import load_dataset +from huggingface_hub import ModelCard +from transformers import HfArgumentParser + + +@dataclass +class ScriptArguments: + r""" + Arguments for the script. + + Args: + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether to push the dataset to the Hugging Face Hub. + repo_id (`str`, *optional*, defaults to `"trl-lib/tldr-preference"`): + Hugging Face repository ID to push the dataset to. + dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): + Number of workers to use for dataset processing. + """ + + push_to_hub: bool = field( + default=False, + metadata={"help": "Whether to push the dataset to the Hugging Face Hub."}, + ) + repo_id: str = field( + default="trl-lib/tldr-preference", + metadata={"help": "Hugging Face repository ID to push the dataset to."}, + ) + dataset_num_proc: Optional[int] = field( + default=None, + metadata={"help": "Number of workers to use for dataset processing."}, + ) + + +def to_preference(example): + info = example["info"] + if example["batch"] in ["batch0_cnndm", "cnndm0", "cnndm2"]: # CNN Daily Mail batches + article = info["article"].replace("\n\n", "\n") + prompt = f"TITLE: {info['title']}\n\n{article}\n\nTL;DR:" + elif example["batch"] in [f"batch{i}" for i in range(3, 23)] + ["edit_b2_eval_test"]: # Reddit batches + post = info["post"].replace("\n\n", "\n") + prompt = f"SUBREDDIT: r/{info['subreddit']}\n\nTITLE: {info['title']}\n\nPOST: {post}\n\nTL;DR:" + else: + raise ValueError(f"Unknown batch: {example['batch']}") + + chosen_idx = example["choice"] + rejected_idx = 1 - chosen_idx + chosen = example["summaries"][chosen_idx]["text"] + rejected = example["summaries"][rejected_idx]["text"] + return {"prompt": prompt, "chosen": chosen, "rejected": rejected} + + +model_card = ModelCard(""" +--- +tags: [trl] +--- + +# TL;DR Dataset for Preference Learning + +## Summary + +The TL;DR dataset is a processed version of Reddit posts, specifically curated to train models using the [TRL library](https://github.com/huggingface/trl) for preference learning and Reinforcement Learning from Human Feedback (RLHF) tasks. It leverages the common practice on Reddit where users append "TL;DR" (Too Long; Didn't Read) summaries to lengthy posts, providing a rich source of paired text data for training models to understand and generate concise summaries. + +## Data Structure + +- **Format**: [Standard](https://huggingface.co/docs/trl/main/dataset_formats#standard) +- **Type**: [Preference](https://huggingface.co/docs/trl/main/dataset_formats#preference) + +Columns: +- `"prompt"`: The unabridged Reddit post. +- `"chosen"`: The concise "TL;DR" summary appended by the author. +- `"rejected"`: An alternative summary or response that was not selected. + +This structure enables models to learn the relationship between detailed content and its abbreviated form, enhancing their summarization capabilities. + +## Generation script + +The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/tldr_preference.py). +""") + +if __name__ == "__main__": + parser = HfArgumentParser(ScriptArguments) + script_args = parser.parse_args_into_dataclasses()[0] + + dataset = load_dataset("openai/summarize_from_feedback", "comparisons") + + dataset = dataset.map( + to_preference, + num_proc=script_args.dataset_num_proc, + remove_columns=["info", "summaries", "choice", "worker", "batch", "split", "extra"], + ) + + if script_args.push_to_hub: + dataset.push_to_hub(script_args.repo_id) + model_card.push_to_hub(script_args.repo_id, repo_type="dataset") diff --git a/examples/datasets/ultrafeedback-prompt.py b/examples/datasets/ultrafeedback-prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..515a4d03d8a155772e688cd179c76215377c3952 --- /dev/null +++ b/examples/datasets/ultrafeedback-prompt.py @@ -0,0 +1,103 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional + +from datasets import load_dataset +from huggingface_hub import ModelCard +from transformers import HfArgumentParser + + +@dataclass +class ScriptArguments: + r""" + Arguments for the script. + + Args: + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether to push the dataset to the Hugging Face Hub. + repo_id (`str`, *optional*, defaults to `"trl-lib/ultrafeedback-prompt"`): + Hugging Face repository ID to push the dataset to. + dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): + Number of workers to use for dataset processing. + """ + + push_to_hub: bool = field( + default=False, + metadata={"help": "Whether to push the dataset to the Hugging Face Hub."}, + ) + repo_id: str = field( + default="trl-lib/ultrafeedback-prompt", + metadata={"help": "Hugging Face repository ID to push the dataset to."}, + ) + dataset_num_proc: Optional[int] = field( + default=None, + metadata={"help": "Number of workers to use for dataset processing."}, + ) + + +def to_unpaired_preference(example): + prompt = [{"role": "user", "content": example["instruction"]}] + return {"prompt": prompt} + + +def drop_long_prompt(example): + if len(example["prompt"][0]["content"]) > 512: + return False + else: + return True + + +model_card = ModelCard(""" +--- +tags: [trl] +--- + +# UltraFeedback - Prompts Dataset + +## Summary + +The UltraFeedback - Prompts dataset is a processed version of the [UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset for model evaluation on specific aspects like helpfulness, honesty, and instruction-following. + +## Data Structure + +- **Format**: [Conversational](https://huggingface.co/docs/trl/main/dataset_formats#conversational) +- **Type**: [Prompt-only](https://huggingface.co/docs/trl/main/dataset_formats#prompt-only) + +Column: +- `"prompt"`: The input question or instruction provided to the model. + +## Generation script + +The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/ultrafeedback-prompt.py). +""") + +if __name__ == "__main__": + parser = HfArgumentParser(ScriptArguments) + script_args = parser.parse_args_into_dataclasses()[0] + + dataset = load_dataset("openbmb/UltraFeedback", split="train") + + dataset = dataset.map( + to_unpaired_preference, + remove_columns=["source", "instruction", "models", "completions", "correct_answers", "incorrect_answers"], + num_proc=script_args.dataset_num_proc, + ) + dataset = dataset.filter(drop_long_prompt) + dataset = dataset.train_test_split(test_size=0.05, seed=42) + + if script_args.push_to_hub: + dataset.push_to_hub(script_args.repo_id) + model_card.push_to_hub(script_args.repo_id, repo_type="dataset") diff --git a/examples/datasets/ultrafeedback.py b/examples/datasets/ultrafeedback.py new file mode 100644 index 0000000000000000000000000000000000000000..2b132195fed234a32e88a084cc8d1831dbd40619 --- /dev/null +++ b/examples/datasets/ultrafeedback.py @@ -0,0 +1,145 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional + +from datasets import load_dataset +from huggingface_hub import ModelCard +from transformers import HfArgumentParser + + +@dataclass +class ScriptArguments: + r""" + Arguments for the script. + + Args: + model_name (`str`, *optional*, defaults to `"gpt-3.5-turbo"`): + Language model to target. Possible values are: + aspect (`str`, *optional*, defaults to `"helpfulness"`): + Aspect to target. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether to push the dataset to the Hugging Face Hub. + repo_id (`str`, *optional*, defaults to `"trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness"`): + Hugging Face repository ID to push the dataset to. + dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): + Number of workers to use for dataset processing. + """ + + model_name: str = field( + default="gpt-3.5-turbo", + metadata={ + "help": "Language model to target.", + "choices": [ + "alpaca-7b", + "bard", + "falcon-40b-instruct", + "gpt-3.5-turbo", + "gpt-4", + "llama-2-13b-chat", + "llama-2-70b-chat", + "llama-2-7b-chat", + "mpt-30b-chat", + "pythia-12b", + "starchat", + "ultralm-13b", + "ultralm-65b", + "vicuna-33b", + "wizardlm-13b", + "wizardlm-70b", + "wizardlm-7b", + ], + }, + ) + aspect: str = field( + default="helpfulness", + metadata={ + "help": "Aspect to target. Possible values are: 'helpfulness' (default), 'honesty', " + "'instruction-following', 'truthfulness'.", + "choices": ["helpfulness", "honesty", "instruction-following", "truthfulness"], + }, + ) + push_to_hub: bool = field( + default=False, + metadata={"help": "Whether to push the dataset to the Hugging Face Hub."}, + ) + repo_id: str = field( + default="trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness", + metadata={"help": "Hugging Face repository ID to push the dataset to."}, + ) + dataset_num_proc: Optional[int] = field( + default=None, + metadata={"help": "Number of workers to use for dataset processing."}, + ) + + +def to_unpaired_preference(example, model_name, aspect): + prompt = [{"role": "user", "content": example["instruction"]}] + model_index = example["models"].index(model_name) + response_content = example["completions"][model_index]["response"] + completion = [{"role": "assistant", "content": response_content}] + score = int(example["completions"][model_index]["annotations"][aspect]["Rating"]) + label = score >= 5 + return {"prompt": prompt, "completion": completion, "label": label} + + +model_card = ModelCard(""" +--- +tags: [trl] +--- + +# UltraFeedback GPT-3.5-Turbo Helpfulness Dataset + +## Summary + +The UltraFeedback GPT-3.5-Turbo Helpfulness dataset contains processed user-assistant interactions filtered for helpfulness, derived from the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. It is designed for fine-tuning and evaluating models in alignment tasks. + +## Data Structure + +- **Format**: [Conversational](https://huggingface.co/docs/trl/main/dataset_formats#conversational) +- **Type**: [Unpaired preference](https://huggingface.co/docs/trl/main/dataset_formats#unpaired-preference) + +Column: +- `"prompt"`: The input question or instruction provided to the model. +- `"completion"`: The model's response to the prompt. +- `"label"`: A binary value indicating whether the response is sufficiently helpful. + +## Generation script + +The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/ultrafeedback.py). +""") + +if __name__ == "__main__": + parser = HfArgumentParser(ScriptArguments) + script_args = parser.parse_args_into_dataclasses()[0] + + dataset = load_dataset("openbmb/UltraFeedback", split="train") + + dataset = dataset.filter( + lambda example: script_args.model_name in example["models"], + batched=False, + num_proc=script_args.dataset_num_proc, + ) + dataset = dataset.map( + to_unpaired_preference, + remove_columns=["source", "instruction", "models", "completions", "correct_answers", "incorrect_answers"], + fn_kwargs={"model_name": script_args.model_name, "aspect": script_args.aspect}, + num_proc=script_args.dataset_num_proc, + ) + dataset = dataset.train_test_split(test_size=0.05, seed=42) + + if script_args.push_to_hub: + dataset.push_to_hub(script_args.repo_id) + model_card.push_to_hub(script_args.repo_id, repo_type="dataset") diff --git a/examples/notebooks/README.md b/examples/notebooks/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0f2a5a8c71a7786c26481714d10d1340b29cbdaf --- /dev/null +++ b/examples/notebooks/README.md @@ -0,0 +1,7 @@ +# Notebooks + +This directory contains a collection of Jupyter notebooks that demonstrate how to use the TRL library in different applications. + +- [`best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb): This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO. +- [`gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb): This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. +- [`gpt2-sentiment-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment-control.ipynb): This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. diff --git a/examples/notebooks/best_of_n.ipynb b/examples/notebooks/best_of_n.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..75ea2815e97f0e9a008aeb62a1b8fe2aafd2df4d --- /dev/null +++ b/examples/notebooks/best_of_n.ipynb @@ -0,0 +1,662 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "WQpNapZNWuXP" + }, + "source": [ + "\n", + "**Best-of-n sampling as an alternative to RLHF**\n", + "\n", + "This notebook compares reward-model scores of prompt based responses from \n", + "1. a base model (`gpt2-imdb`)\n", + "2. `RLHF` tuned model based on this base-model \n", + "3. the base-model again from which we sample n responses to each prompt, score them and take the best scored one AKA the `best-of-n sampled` model\n", + "\n", + "Import dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "vDA6qayz692w" + }, + "outputs": [], + "source": [ + "%pip install transformers trl" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "M1s_iNm773hM" + }, + "outputs": [], + "source": [ + "import torch\n", + "import pandas as pd\n", + "\n", + "from transformers import pipeline, AutoTokenizer\n", + "from datasets import load_dataset\n", + "\n", + "from trl import AutoModelForCausalLMWithValueHead\n", + "from trl.core import LengthSampler\n", + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Y7hyrIrO8tcY" + }, + "source": [ + "Various constants" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "MqS3OM6Q8x6g" + }, + "outputs": [], + "source": [ + "ref_model_name = \"lvwerra/gpt2-imdb\"\n", + "model_name = \"lvwerra/gpt2-imdb-pos-v2\"\n", + "reward_model = \"lvwerra/distilbert-imdb\"\n", + "\n", + "N_BEST_OF = 4" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "c1YcXeElg6or" + }, + "source": [ + "Models and tokenizers" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "b855NrL181Hh" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/kashif/Github/transformers/src/transformers/tokenization_utils_base.py:1617: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/plain": [ + "AutoModelForCausalLMWithValueHead(\n", + " (pretrained_model): GPT2LMHeadModel(\n", + " (transformer): GPT2Model(\n", + " (wte): Embedding(50257, 768)\n", + " (wpe): Embedding(1024, 768)\n", + " (drop): Dropout(p=0.1, inplace=False)\n", + " (h): ModuleList(\n", + " (0-11): 12 x GPT2Block(\n", + " (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (attn): GPT2SdpaAttention(\n", + " (c_attn): Conv1D(nf=2304, nx=768)\n", + " (c_proj): Conv1D(nf=768, nx=768)\n", + " (attn_dropout): Dropout(p=0.1, inplace=False)\n", + " (resid_dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " (mlp): GPT2MLP(\n", + " (c_fc): Conv1D(nf=3072, nx=768)\n", + " (c_proj): Conv1D(nf=768, nx=3072)\n", + " (act): NewGELUActivation()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n", + " )\n", + " (v_head): ValueHead(\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (summary): Linear(in_features=768, out_features=1, bias=True)\n", + " (flatten): Flatten(start_dim=1, end_dim=-1)\n", + " )\n", + ")" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)\n", + "\n", + "ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)\n", + "\n", + "reward_pipe = pipeline(\"sentiment-analysis\", model=reward_model, device=device)\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(ref_model_name)\n", + "\n", + "tokenizer.pad_token = tokenizer.eos_token\n", + "\n", + "# cuda-ize models\n", + "model.to(device)\n", + "ref_model.to(device)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Z1Cz0gCFhZYJ" + }, + "source": [ + "Dataset building" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "LqLVEp5p_8XM" + }, + "outputs": [], + "source": [ + "def build_dataset(\n", + " tokenizer,\n", + " dataset_name=\"stanfordnlp/imdb\",\n", + " input_min_text_length=2,\n", + " input_max_text_length=8,\n", + "):\n", + " # load imdb with datasets\n", + " ds = load_dataset(dataset_name, split=\"train\")\n", + " ds = ds.rename_columns({\"text\": \"review\"})\n", + " ds = ds.filter(lambda x: len(x[\"review\"]) > 200, batched=False)\n", + "\n", + " input_size = LengthSampler(input_min_text_length, input_max_text_length)\n", + "\n", + " def tokenize(sample):\n", + " sample[\"input_ids\"] = tokenizer.encode(sample[\"review\"])[: input_size()]\n", + " sample[\"query\"] = tokenizer.decode(sample[\"input_ids\"])\n", + " return sample\n", + "\n", + " ds = ds.map(tokenize, batched=False)\n", + " ds.set_format(type=\"torch\")\n", + " return ds\n", + "\n", + "\n", + "dataset = build_dataset(tokenizer)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "AqA2McjMAxNw" + }, + "outputs": [], + "source": [ + "gen_kwargs = {\n", + " \"min_length\": -1,\n", + " \"top_k\": 0.0,\n", + " \"top_p\": 1.0,\n", + " \"do_sample\": True,\n", + " \"pad_token_id\": tokenizer.eos_token_id,\n", + "}\n", + "sent_kwargs = {\"top_k\": None, \"function_to_apply\": \"none\", \"batch_size\": 16}" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "L_q4qs35AxcR" + }, + "outputs": [], + "source": [ + "output_min_length = 4\n", + "output_max_length = 16\n", + "output_length_sampler = LengthSampler(output_min_length, output_max_length)\n", + "\n", + "#### get a batch from the dataset\n", + "bs = 16\n", + "output_data = dict()\n", + "dataset.set_format(\"pandas\")\n", + "df_batch = dataset[:].sample(bs)\n", + "output_data[\"query\"] = df_batch[\"query\"].tolist()\n", + "query_tensors = df_batch[\"input_ids\"].tolist()\n", + "\n", + "# :: [Resp]\n", + "response_tensors_ref, response_tensors = [], []\n", + "# :: [[Resp]]\n", + "response_tensors_best_of = []" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QVfpyHnZBLKY" + }, + "source": [ + "\n", + "Generation using various models" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "id": "-imZ7uEFBNbw" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n" + ] + } + ], + "source": [ + "for i in range(bs):\n", + " gen_len = output_length_sampler()\n", + "\n", + " query = torch.tensor(query_tensors[i])\n", + "\n", + " output = ref_model.generate(\n", + " query.unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs\n", + " ).squeeze()\n", + " response_tensors_ref.append(tokenizer.decode(output))\n", + "\n", + " output = model.generate(\n", + " query.unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs\n", + " ).squeeze()\n", + " response_tensors.append(tokenizer.decode(output))\n", + "\n", + " # generating copies of the same query for the Best-of-n sampling\n", + " queries = query.repeat((N_BEST_OF, 1))\n", + " output = ref_model.generate(\n", + " queries.to(device), max_new_tokens=gen_len, **gen_kwargs\n", + " ).squeeze()\n", + " response_tensors_best_of.append(tokenizer.batch_decode(output))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Jp5FC0Y5h_Sf" + }, + "source": [ + "Scoring" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "id": "PyDbbAQ0F_h7" + }, + "outputs": [], + "source": [ + "scores_ref = [\n", + " output[0][\"score\"] for output in reward_pipe(response_tensors_ref, **sent_kwargs)\n", + "]\n", + "scores = [output[0][\"score\"] for output in reward_pipe(response_tensors, **sent_kwargs)]\n", + "scores_best_of = []\n", + "for i, response in enumerate(response_tensors_best_of):\n", + " # base_score = scores_ref[i]\n", + " scores_best_of.append(\n", + " torch.tensor(\n", + " [output[0][\"score\"] for output in reward_pipe(response, **sent_kwargs)]\n", + " )\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 682 + }, + "id": "nA1GDNJEiGm-", + "outputId": "1389c686-0751-4304-dea2-b71fd68748e1" + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
queryresponse (ref)scores (ref)response (RLHF)scores (RLHF)response (best_of)scores (best_of)
0This movieThis movie should have read some books, and1.411889This movie has plenty of extraordinary feature...2.735337This movie was unexpectedly funny and funny, you2.405301
1OK where do i begin?OK where do i begin? *** Acting is decent (not...1.555380OK where do i begin? For all of you who are no...0.019694OK where do i begin? i just wanted to add some...0.622912
2I watchedI watched one can compare themselves upon view...1.380120I watched it because of its excellent cast. Th...2.498309I watched the trial trial for teaches us a goo...2.057187
3It's been 19 years since GordonIt's been 19 years since Gordon finally left c...1.554914It's been 19 years since Gordon Tree has becom...1.632266It's been 19 years since Gordon Clarke put me ...2.783458
4Just kiddingJust kidding; I know a lot-0.069533Just kidding \"Third World Snopes0.944632Just kidding, I didn't even1.945202
5shakespeare's plays have a wayshakespeare's plays have a way of weaving into...1.656927shakespeare's plays have a way. It's the look ...1.444803shakespeare's plays have a way of getting back...1.834373
6This movie is wonderful. WhatThis movie is wonderful. What could have been ...2.749068This movie is wonderful. What someone likes ab...2.759510This movie is wonderful. What a different look,2.695312
7I lovedI loved this film. <br /><2.576181I loved it, and I really loved Audrey2.578412I loved this film. Reading reviews of it2.751773
8A superb andA superb and very cool drama. The novel is2.910374A superb and super fun movie that removes all the2.783201A superb and most finely acted role that I will2.894923
9I rememberI remember.Very poor execution but good movies0.923775I remember when Shelter saw some girls on TV0.825408I remember thinking to myself how SOMEONE who1.634163
10This su*kThis su*k camel down your kidd1.605957This su*k Dress! I loved it2.345865This su*k like a roll of crap2.422874
11One StinkOne Stink Act...<br /><br1.456476One Stinkl was a great actor, particularly1.782818One Stink?: Invisible of Saint Barbara, poor1.667756
12I pulled down a VHSI pulled down a VHS copy and watched it with m...0.756151I pulled down a VHS looking a good looking, and a-0.008258I pulled down a VHS copy the other day and all I0.992919
13For someFor some alone no more Buddy Trumbull would ha...0.790762For some enthraled time, the film will impress...2.455694For some reason, a bomb crashed on the rear of...0.857423
14This one features allThis one features all the good elements of spi...1.452079This one features all kinds of wit and humor r...2.743043This one features all the best Birdprogram sup...2.343950
15Somehow a woman working withSomehow a woman working with Jim Wynorski prof...0.242172Somehow a woman working with her daughter play...0.092226Somehow a woman working with an overweight ins...1.415525
\n", + "
" + ], + "text/plain": [ + " query \\\n", + "0 This movie \n", + "1 OK where do i begin? \n", + "2 I watched \n", + "3 It's been 19 years since Gordon \n", + "4 Just kidding \n", + "5 shakespeare's plays have a way \n", + "6 This movie is wonderful. What \n", + "7 I loved \n", + "8 A superb and \n", + "9 I remember \n", + "10 This su*k \n", + "11 One Stink \n", + "12 I pulled down a VHS \n", + "13 For some \n", + "14 This one features all \n", + "15 Somehow a woman working with \n", + "\n", + " response (ref) scores (ref) \\\n", + "0 This movie should have read some books, and 1.411889 \n", + "1 OK where do i begin? *** Acting is decent (not... 1.555380 \n", + "2 I watched one can compare themselves upon view... 1.380120 \n", + "3 It's been 19 years since Gordon finally left c... 1.554914 \n", + "4 Just kidding; I know a lot -0.069533 \n", + "5 shakespeare's plays have a way of weaving into... 1.656927 \n", + "6 This movie is wonderful. What could have been ... 2.749068 \n", + "7 I loved this film.
< 2.576181 \n", + "8 A superb and very cool drama. The novel is 2.910374 \n", + "9 I remember.Very poor execution but good movies 0.923775 \n", + "10 This su*k camel down your kidd 1.605957 \n", + "11 One Stink Act...

Optimise GPT2 to produce IMDB movie reviews with controlled sentiment using a BERT sentiment classifier for rewards.\n", + "\n", + "**WARNING:** We often experienced loss spikes in this examples which caused model training to fail or slow down. There is a [GitHub issue](https://github.com/lvwerra/trl/issues/101) to track the issue." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "

Figure: Experiment setup to tune GPT2. The yellow arrows are outside the scope of this notebook, but the trained models are available through Hugging Face.

\n", + "
\n", + "\n", + "\n", + "The experiment setup is very similar to the positive sentiment notebook. However, in this notebook we fine-tune GPT2 (small) to generate **controlled** movie reviews based on the IMDB dataset. The model gets the target sentiment and 5 tokens from a real review and is tasked to produce continuations with the targeted sentiment. The reward for the continuations is calculated with the logits of a BERT sentiment classifier. That reward is then used for PPO training." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup experiment" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Import dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/leandro_huggingface_co/miniconda3/envs/trl/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import random\n", + "import torch\n", + "import wandb\n", + "import time\n", + "import os\n", + "from tqdm import tqdm\n", + "import numpy as np\n", + "import pandas as pd\n", + "from random import choices\n", + "import matplotlib.pyplot as plt\n", + "\n", + "tqdm.pandas()\n", + "\n", + "from datasets import load_dataset\n", + "\n", + "from transformers import AutoTokenizer, pipeline\n", + "\n", + "from trl import (\n", + " PPOTrainer,\n", + " PPOConfig,\n", + " AutoModelForCausalLMWithValueHead,\n", + " create_reference_model,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Configuration" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "sentiment_pipe_kwargs = {\"top_k\": None, \"function_to_apply\": \"none\"}\n", + "\n", + "config = PPOConfig(\n", + " model_name=\"lvwerra/gpt2-imdb\",\n", + " steps=51200,\n", + " learning_rate=1.41e-5,\n", + " remove_unused_columns=False,\n", + " log_with=\"wandb\",\n", + ")\n", + "\n", + "txt_in_len = 5\n", + "txt_out_len = 20\n", + "seed = 1" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "np.random.seed(seed)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can see that we load a GPT2 model called `gpt2_imdb`. This model was additionally fine-tuned on the IMDB dataset for 1 epoch with the huggingface [script](https://github.com/huggingface/transformers/blob/master/examples/run_language_modeling.py) (no special settings). The other parameters are mostly taken from the original paper [\"Fine-Tuning Language Models from Human Preferences\"](\n", + "https://huggingface.co/papers/1909.08593). This model as well as the BERT model is available in the Huggingface model zoo [here](https://huggingface.co/models). The following code should automatically download the models." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load data and models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load pre-trained GPT2 language models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We load the GPT2 model with a value head and the tokenizer. We load the model twice; the first model is optimized while the second model serves as a reference to calculate the KL-divergence from the starting point. This serves as an additional reward signal in the PPO training to make sure the optimized model does not deviate too much from the original language model." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)\n", + "gpt2_ref_model = create_reference_model(gpt2_model)\n", + "gpt2_tokenizer = AutoTokenizer.from_pretrained(config.model_name)\n", + "\n", + "gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load IMDB dataset\n", + "The IMDB dataset contains 50k movie review annotated with \"positive\"/\"negative\" feedback indicating the sentiment. We load the IMDB dataset into a DataFrame and filter for comments that are at least 500 characters long and take the first 1000 characters of each comment. The first filter we apply to avoid comments that are less than `txt_in_len` token long and the second to avoid tokenizing way more text than we actually need." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Found cached dataset imdb (/home/leandro_huggingface_co/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)\n", + "Loading cached processed dataset at /home/leandro_huggingface_co/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-d314b4c14499bf03.arrow\n", + "Loading cached processed dataset at /home/leandro_huggingface_co/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-0d5fcb05c95b1186.arrow\n" + ] + }, + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['review', 'sentiment'],\n", + " num_rows: 22578\n", + "})" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# create the dataset\n", + "#\n", + "dataset = load_dataset(\"stanfordnlp/imdb\", split=\"train\")\n", + "dataset = dataset.rename_columns({\"text\": \"review\", \"label\": \"sentiment\"})\n", + "# make sure the comments are are at least 500 and trim to 1000\n", + "dataset = dataset.filter(lambda x: len(x[\"review\"]) > 500, batched=False)\n", + "dataset = dataset.map(lambda x: {\"review\": x[\"review\"][:1000]}, batched=False)\n", + "\n", + "dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Tokenize IMDB reviews" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We tokenize all IMDB in advance to avoid tokenizing twice. In the first step we encode the queries and slice the first `txt_in_len` tokens. In a second step we decode these tokens back to text for later display." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading cached processed dataset at /home/leandro_huggingface_co/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-383f6ebf0ae41ee4.arrow\n", + "Loading cached processed dataset at /home/leandro_huggingface_co/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-f4875ad4fccbbc1f.arrow\n" + ] + } + ], + "source": [ + "dataset = dataset.map(\n", + " lambda x: {\n", + " \"input_ids\": gpt2_tokenizer.encode(\" \" + x[\"review\"], return_tensors=\"pt\")[\n", + " 0, :txt_in_len\n", + " ]\n", + " },\n", + " batched=False,\n", + ")\n", + "dataset = dataset.map(\n", + " lambda x: {\"query\": gpt2_tokenizer.decode(x[\"input_ids\"])}, batched=False\n", + ")\n", + "dataset = dataset[:20480]\n", + "\n", + "from datasets import Dataset\n", + "\n", + "dataset = Dataset.from_dict(dataset)\n", + "dataset.set_format(\"pytorch\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 770, 2646, 373, 2192, 7867])" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset[3][\"input_ids\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "def collator(data):\n", + " return dict((key, [d[key] for d in data]) for key in data[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mlvwerra\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.13.9" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /home/leandro_huggingface_co/trl/examples/sentiment/notebooks/wandb/run-20230206_125743-jpcnr7jx" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run comic-music-184 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/lvwerra/trl" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/lvwerra/trl/runs/jpcnr7jx" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "ppo_trainer = PPOTrainer(\n", + " config, gpt2_model, gpt2_ref_model, gpt2_tokenizer, dataset, data_collator=collator\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load BERT classifier\n", + "We load a BERT classifier fine-tuned on the IMDB dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "if ppo_trainer.accelerator.num_processes == 1:\n", + " device = 0 if torch.cuda.is_available() else \"cpu\" # to avoid a `pipeline` bug\n", + "else:\n", + " device = ppo_trainer.accelerator.device\n", + "sentiment_pipe = pipeline(\n", + " \"sentiment-analysis\", \"lvwerra/distilbert-imdb\", device=device\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The model outputs are the logits for the negative and positive class. We will use the logits for positive class as a reward signal for the language model." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'label': 'NEGATIVE', 'score': 2.3350484371185303},\n", + " {'label': 'POSITIVE', 'score': -2.726576328277588}]" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text = \"this movie was really bad!!\"\n", + "output = sentiment_pipe(text, **sentiment_pipe_kwargs)\n", + "output" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'label': 'POSITIVE', 'score': 2.557040214538574},\n", + " {'label': 'NEGATIVE', 'score': -2.294790267944336}]" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text = \"this movie was really good!!\"\n", + "output = sentiment_pipe(text, **sentiment_pipe_kwargs)\n", + "output" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'label': 'POSITIVE', 'score': 0.8562759160995483},\n", + " {'label': 'NEGATIVE', 'score': -0.7086048126220703}]" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text = \"this movie was a documentary\"\n", + "output = sentiment_pipe(text, **sentiment_pipe_kwargs)\n", + "output" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The resulting reward signal:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "def extract_pipe_output(outputs):\n", + " positive_logits = []\n", + " for out in outputs:\n", + " for element in out:\n", + " if element[\"label\"] == \"POSITIVE\":\n", + " positive_logits.append(torch.tensor(element[\"score\"]))\n", + " return positive_logits" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "-0.7086048126220703" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output[1][\"score\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Control token dict\n", + "We will append the control token at the beginning of each query to signal the model what the target sentiment is. Each control sequence consists of three tokens:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "ctrl_str = [\"[negative]\", \"[neutral]\", \"[positive]\"]\n", + "device = torch.device(\n", + " \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + ") # this should be handled by accelerate\n", + "ctrl_tokens = dict(\n", + " (s, gpt2_tokenizer.encode(s, return_tensors=\"pt\").squeeze().to(device))\n", + " for s in ctrl_str\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'[negative]': tensor([ 58, 31591, 60], device='cuda:0'),\n", + " '[neutral]': tensor([ 58, 29797, 60], device='cuda:0'),\n", + " '[positive]': tensor([ 58, 24561, 60], device='cuda:0')}" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ctrl_tokens" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Reward function" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "def pos_logit_to_reward(logit, task):\n", + " \"\"\"\n", + " Take the positive sentiment logit and scale it for the task.\n", + " task [negative]: reward = -logit\n", + " task [neutral]: reward = -2*abs(logit)+4\n", + " task [positive]: reward = logit\n", + " \"\"\"\n", + " for i in range(len(logit)):\n", + " if task[i] == \"[negative]\":\n", + " logit[i] = -logit[i]\n", + " elif task[i] == \"[neutral]\":\n", + " logit[i] = -2 * torch.abs(logit[i]) + 4\n", + " elif task[i] == \"[positive]\":\n", + " pass\n", + " else:\n", + " raise ValueError(\"task has to be in [0, 1, 2]!\")\n", + " return logit" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The following examples show the rewards for the cases where the classifier logit is 4, -4 and 0 for the three targets `['negative]`, `['neutral]` and `['positive']`. The scaling is not perfect as it differs between neutral and the other two classes. This is something to further investigate in the future. Ideally, one would use the logit output for each class individually, but since there is no dedicated class for neutral this is a workaround." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['[negative]', '[neutral]', '[positive]']\n" + ] + } + ], + "source": [ + "print(ctrl_str)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([-4., -4., 4.])" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pos_logit_to_reward(torch.Tensor([4, 4, 4]), ctrl_str)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 4., -4., -4.])" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pos_logit_to_reward(torch.Tensor([-4, -4, -4]), ctrl_str)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([-0., 4., 0.])" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pos_logit_to_reward(torch.Tensor([0, 0, 0]), ctrl_str)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Generation settings" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "generation_kwargs = {\n", + " \"min_length\": -1,\n", + " \"top_k\": 0.0,\n", + " \"top_p\": 1.0,\n", + " \"do_sample\": True,\n", + " \"pad_token_id\": gpt2_tokenizer.eos_token_id,\n", + " \"max_new_tokens\": txt_out_len,\n", + " \"eos_token_id\": -1,\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Optimize model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Steps**\n", + "\n", + "The training loop consists of the following steps:\n", + "1. Get a batch of queries and create random controls\n", + "2. Get the query responses from the policy\n", + "3. Join query and responses and tokenize for BERT analysis\n", + "4. Get sentiments for query/responses from BERT\n", + "5. Optimize policy with PPO using the (query, response, reward) triplet\n", + "6. Log all the training statistics\n", + "\n", + "**Training time**\n", + "\n", + "This step takes **~2h** on a P6000 GPU with the above specified settings." + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 8%|▊ | 6/80 [12:44<2:37:54, 128.03s/it]/home/leandro_huggingface_co/miniconda3/envs/trl/lib/python3.9/site-packages/transformers/pipelines/base.py:1045: UserWarning: You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset\n", + " warnings.warn(\n", + "100%|██████████| 80/80 [2:46:39<00:00, 124.99s/it] \n", + " 91%|█████████▏| 73/80 [2:30:39<14:35, 125.03s/it] " + ] + } + ], + "source": [ + "for epoch in range(2):\n", + " for batch in tqdm(ppo_trainer.dataloader):\n", + " (\n", + " logs,\n", + " game_data,\n", + " ) = (\n", + " dict(),\n", + " dict(),\n", + " )\n", + "\n", + " #### prepend a random control token\n", + " task_list = choices(ctrl_str, k=config.batch_size)\n", + " game_data[\"query\"] = [t + q for t, q in zip(task_list, batch[\"query\"])]\n", + " query_tensors = [\n", + " torch.cat((ctrl_tokens[t], input_ids))\n", + " for t, input_ids in zip(task_list, batch[\"input_ids\"])\n", + " ]\n", + "\n", + " #### get response from gpt2\n", + " response_tensors = []\n", + " for query in query_tensors:\n", + " response = ppo_trainer.generate(query, **generation_kwargs)\n", + " response_tensors.append(response.squeeze()[-txt_out_len:])\n", + " game_data[\"response\"] = [\n", + " gpt2_tokenizer.decode(r.squeeze()) for r in response_tensors\n", + " ]\n", + "\n", + " #### sentiment analysis\n", + " texts = [q + r for q, r in zip(batch[\"query\"], game_data[\"response\"])]\n", + " logits = extract_pipe_output(sentiment_pipe(texts, **sentiment_pipe_kwargs))\n", + " rewards = pos_logit_to_reward(logits, task_list)\n", + "\n", + " #### Run PPO training\n", + " t = time.time()\n", + " stats = ppo_trainer.step(query_tensors, response_tensors, rewards)\n", + "\n", + " for cs in ctrl_str:\n", + " key = \"env/reward_\" + cs.strip(\"[]\")\n", + " stats[key] = np.mean(\n", + " [r.cpu().numpy() for r, t in zip(rewards, task_list) if t == cs]\n", + " )\n", + " ppo_trainer.log_stats(stats, game_data, rewards)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training progress\n", + "If you are tracking the training progress with Weights&Biases you should see a plot similar to the following:\n", + "\n", + "
\n", + "\n", + "

Figure: Reward mean and distribution evolution during training.

\n", + "
\n", + "\n", + "One can observe how the model starts to generate more positive outputs after a few optimisation steps.\n", + "\n", + "> Note: Investigating the KL-divergence will probably show that at this point the model has not converged to the target KL-divergence, yet. To get there would require longer training or starting with a higher inital coefficient." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model inspection" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Reward distribution\n", + "First, we can have a look at the reward distribution. Both the negative and positive rewards are clearly shifted to high rewards. The neutral rewards, however, are still centered around zero. There are a few possible explanations for this. There could be a bug in the code and the way the neutral rewards are calculated. Another problem could be that sentence sometimes start with a strong sentiment and it is hard for the model shift the sentiment towards neutral." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAGzCAYAAAAMr0ziAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAABPCUlEQVR4nO3deVwVZf8//tecw4HDroiyibKImSmQkKi5B6K3mt4tLvm4RSq7S7lvjTtNLAVcPqip0aLZnbdL3ZK0qP2+5o0SSVmiFor7lklubGqIgB4OnPn9YWfyyGE5h+UM8Ho+Hjw8c80117znOoPzZuaaGUEURRFEREREMqawdABEREREdWHCQkRERLLHhIWIiIhkjwkLERERyR4TFiIiIpI9JixEREQke0xYiIiISPaYsBAREZHsMWEhIiIi2WPCQkQNtmnTJgiCgNzcXLOWnzZtGnx8fAzKBEFAQkJCg2OrS2ZmJgRBQGZmplQ2dOhQ9OrVq8nXDQC5ubkQBAGbNm1qlvURtVRMWIio1UhJSUFycrKlwzBKzrERtQRWlg6AiMiYO3fuwMrKtP+iUlJScOLECcyePbveywwePBh37tyBtbW1iRGapqbYunbtijt37kClUjXp+olaOp5hIZKBsrIyS4dQK51Oh7t37zbrOtVqtckJiynu3r0LnU4HhUIBtVoNhcIy/x0KggC1Wg2lUmmR9RO1FExYiJpZQkICBEHAqVOn8Nxzz6F9+/YYOHCgNP+///0vQkJCYGtrCxcXF0yaNAmXL1+W5r/77rtQKpUoLi6WylatWgVBEBAbGyuVVVVVwdHREa+//rpUtnLlSgwYMAAdOnSAra0tQkJC8MUXX1SLURAExMTEYMuWLXjkkUdgY2ODtLQ0AMDJkycxfPhw2NraonPnzliyZAl0Ol29t3/Hjh3o1asX1Go1evXqhe3btxut9+AYltu3b2P27Nnw8fGBjY0NOnXqhIiICBw+fBjAvXEnX3/9NX777TcIggBBEKRxMfpxKlu3bsWbb74JLy8v2NnZoaSkxOgYFr3s7GwMGDAAtra28PX1xbp16wzm1zR258E2a4utpjEs3377LQYNGgR7e3u0a9cO48aNw+nTpw3q6PelX375BdOmTUO7du3g7OyM6OholJeX1/wlELVAvCREZCHPPvssAgIC8H//938QRREAsHTpUixYsAATJkzAiy++iKKiIrz33nsYPHgwjhw5gnbt2mHQoEHQ6XT44YcfMGbMGADAvn37oFAosG/fPqn9I0eOoLS0FIMHD5bK3nnnHTz55JOYMmUKKioqsHXrVjz77LPYuXMnRo8ebRDft99+i88++wwxMTFwdXWFj48P8vPzMWzYMFRWVmLevHmwt7fHv//9b9ja2tZrm/fs2YOnn34aPXv2RFJSEm7cuIHo6Gh07ty5zmVffvllfPHFF4iJiUHPnj1x48YN/PDDDzh9+jT69OmDN954A7du3cKVK1fw9ttvAwAcHBwM2li8eDGsra3x2muvQaPR1HoZ6Pfff8df/vIXTJgwAZMnT8Znn32GV155BdbW1nj++efrtb169Yntft988w1GjRoFPz8/JCQk4M6dO3jvvffw+OOP4/Dhw9UGKE+YMAG+vr5ISkrC4cOHsX79enTq1AnLly83KU4iWROJqFnFx8eLAMTJkycblOfm5opKpVJcunSpQfnx48dFKysrqbyqqkp0cnIS586dK4qiKOp0OrFDhw7is88+KyqVSvH27duiKIri6tWrRYVCIf7+++9SW+Xl5QZtV1RUiL169RKHDx9uUA5AVCgU4smTJw3KZ8+eLQIQDx48KJUVFhaKzs7OIgDx4sWLtW57cHCw6OHhIRYXF0tle/bsEQGIXbt2rRZDfHy8NO3s7CzOnDmz1vZHjx5drR1RFMW9e/eKAEQ/P79qfaCft3fvXqlsyJAhIgBx1apVUplGoxGDg4PFTp06iRUVFaIoiuLGjRuNbrexNmuK7eLFiyIAcePGjVKZfj03btyQyo4ePSoqFApx6tSpUpl+X3r++ecN2vzrX/8qdujQodq6iFoyXhIispCXX37ZYHrbtm3Q6XSYMGECrl+/Lv24u7sjICAAe/fuBQAoFAoMGDAA33//PQDg9OnTuHHjBubNmwdRFJGVlQXg3lmXXr16oV27dtI67j8T8vvvv+PWrVsYNGiQdFnlfkOGDEHPnj0Nynbt2oV+/fqhb9++UlnHjh0xZcqUOrc3Ly8POTk5iIqKgrOzs1QeERFRbT3GtGvXDgcPHsS1a9fqrFuTqKioep8NsrKywt///ndp2traGn//+99RWFiI7Oxss2Ooi76fpk2bBhcXF6k8MDAQERER2LVrV7VlHtyXBg0ahBs3bqCkpKTJ4iRqbkxYiCzE19fXYPr8+fMQRREBAQHo2LGjwc/p06dRWFgo1R00aBCys7Nx584d7Nu3Dx4eHujTpw+CgoKky0I//PADBg0aZLCOnTt3ol+/flCr1XBxcUHHjh3xwQcf4NatW3XGBwC//fYbAgICqpU/9NBDdW7vb7/9BgBmL79ixQqcOHEC3t7e6Nu3LxISEvDrr7/Wudz9jG1TTTw9PWFvb29Q1r17dwAw+3kz9aHvJ2N98vDDD+P69evVBml36dLFYLp9+/YA7iWlRK0Fx7AQWciDf+nrdDoIgoD//e9/Ru8YuX/Mw8CBA6HVapGVlYV9+/ZJicmgQYOwb98+nDlzBkVFRQYJy759+/Dkk09i8ODBWLt2LTw8PKBSqbBx40akpKTUGZ+lTZgwAYMGDcL27duxZ88evPXWW1i+fDm2bduGUaNG1auNxt4mQRCMlldVVTXqeupS0x1G4h9jo4haAyYsRDLh7+8PURTh6+sr/SVfk759+8La2hr79u3Dvn37MGfOHAD3niny0UcfISMjQ5rW+/LLL6FWq7F7927Y2NhI5Rs3bqx3jF27dsX58+erlZ89e7ZeywIwe3kA8PDwwIwZMzBjxgwUFhaiT58+WLp0qZSw1JRAmOPatWsoKyszOMty7tw5AJAGverPZNx/xxbw51mS+9U3Nn0/GeuTM2fOwNXVtdqZH6K2gJeEiGTiqaeeglKpRGJiYrW/jEVRxI0bN6RptVqNxx57DJ9++ikuXbpkcIblzp07ePfdd+Hv7w8PDw9pGaVSCUEQDP76z83NxY4dO+od41/+8hccOHAAhw4dksqKioqwZcuWOpf18PBAcHAwNm/ebHAJKj09HadOnap12aqqqmqXrTp16gRPT09oNBqpzN7e3ujlLXNUVlbiww8/lKYrKirw4YcfomPHjggJCQFwL8kEII0n0sf673//u1p79Y3t/n66PxE6ceIE9uzZg7/85S/mbhJRi8YzLEQy4e/vjyVLliAuLg65ubkYP348HB0dcfHiRWzfvh0vvfQSXnvtNan+oEGDsGzZMjg7O6N3794A7h3EH3roIZw9exbTpk0zaH/06NFYvXo1Ro4cieeeew6FhYVYs2YNunXrhmPHjtUrxrlz5+KTTz7ByJEjMWvWLOm25q5du9arjaSkJIwePRoDBw7E888/j5s3b+K9997DI488gtLS0hqXu337Njp37oxnnnkGQUFBcHBwwDfffIOffvoJq1atkuqFhIQgNTUVsbGxeOyxx+Dg4ICxY8fWa9se5OnpieXLlyM3Nxfdu3dHamoqcnJy8O9//1t6Ku0jjzyCfv36IS4uDjdv3oSLiwu2bt2KysrKau2ZEttbb72FUaNGoX///njhhRek25qdnZ2b5f1KRLJkyVuUiNoi/a2oRUVFRud/+eWX4sCBA0V7e3vR3t5e7NGjhzhz5kzx7NmzBvW+/vprEYA4atQog/IXX3xRBCD+5z//qdb2f/7zHzEgIEC0sbERe/ToIW7cuFGK534AaryF+NixY+KQIUNEtVotenl5iYsXLxb/85//1Ou2Zv32Pfzww6KNjY3Ys2dPcdu2bWJUVFSttzVrNBpxzpw5YlBQkOjo6Cja29uLQUFB4tq1aw2WKS0tFZ977jmxXbt2BrdK628z/vzzz6vFU9NtzY888oj4888/i/379xfVarXYtWtX8f3336+2/IULF8Tw8HDRxsZGdHNzE+fPny+mp6dXa7Om2Izd1iyKovjNN9+Ijz/+uGhrays6OTmJY8eOFU+dOmVQp6Z9qabbrYlaMkEUOSqLiIiI5I1jWIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JCxEREcleq3hwnE6nw7Vr1+Do6Nioj+YmIiKipiOKIm7fvg1PT08oFLWfQ2kVCcu1a9fg7e1t6TCIiIjIDJcvX0bnzp1rrdMqEhZHR0cA9zbYycnJ7Ha0Wi327NmDESNGSI/ebovYD+wDgH0AsA/02A/sA6Bp+qCkpATe3t7Scbw2rSJh0V8GcnJyanDCYmdnBycnpza7QwLsB4B9ALAPAPaBHvuBfQA0bR/UZzgHB90SERGR7DFhISIiItljwkJERESy1yrGsNSHKIqorKxEVVVVjXW0Wi2srKxw9+7dWuu1di2xH1QqFZRKpaXDICKiJtImEpaKigrk5eWhvLy81nqiKMLd3R2XL19u089zaYn9IAgCOnfuDAcHB0uHQkRETaDVJyw6nQ4XL16EUqmEp6cnrK2tazwI63Q6lJaWwsHBoc4H2LRmLa0fRFFEUVERrly5goCAAJ5pISJqhVp9wlJRUQGdTgdvb2/Y2dnVWlen06GiogJqtbpFHKibSkvsh44dOyI3NxdarZYJCxFRK9QyjkaNoKUceMk8LeXSFRERmYdHcSIiIpI9JixEREQke61+DEtt3k4/ZzAtiiI0Gg1sbGya5BLDqxHdTao/dOhQfPfddwCAI0eOIDg4uNFjamyCIGD79u0YP358o7SXmZmJYcOGAQDGjRuHHTt2NEq7RETUsvAMi8xNnz4deXl56NWrl6VDMZCQkGA0gcrLy8OoUaMabT0DBgxAXl4eJkyY0GhtEhFRy9Omz7C0BHZ2dnB3d7d0GPXW2LFaW1vD3d0dtra20Gg0jdo2ERG1HDzD0oJkZmZCEARkZGQgNDQUdnZ2GDBgAM6ePWtQ76uvvkKfPn2gVqvh5+eHxMREVFZWSvPPnDmDgQMHQq1Wo2fPnvjmm28gCILB5Zb4+Hj06NEDdnZ28PPzw4IFC6DVagEAmzZtQmJiIo4ePQpBECAIAjZt2gQABu0MGDAAr7/+ukFsRUVFUKlU+P777wEAGo0Gr732Gry8vGBvb4+wsDBkZmY2bscREVGLxzMsLdAbb7yBVatWoWPHjnj55Zfx/PPP48cffwQA7Nu3D1OnTsW7776LQYMG4cKFC3jppZcA3EtCqqqqMH78eHTp0gUHDx7E7du38a9//avaOhwdHbFhwwZ07twZx48fx/Tp0+Ho6Ii5c+di4sSJOHHiBNLS0vDNN98AAJydnau1MWXKFKxYsQLLli2TxgSlpqbC09MTgwYNAgDExMTg1KlT2Lp1Kzw9PbF9+3aMHDkSx48fR0BAQJP0H9Vsbc5a6bOgE+AJT6w/vh6iQrRgVMCM4BkWXT8RWR7PsLRAS5cuxZAhQ9CzZ0/MmzcP+/fvx927dwEAiYmJmDdvHqKiouDn54eIiAgsXrwYH374IQAgPT0dFy5cwMcff4ygoCAMHDgQS5curbaO1157DQMGDICPjw/Gjh2L1157DZ999hkAwNbWFg4ODrCysoK7u7t0yeZBEyZMwLVr1/DDDz9IZSkpKZg8eTIEQcClS5ewceNGfP755xg0aBD8/f3x2muvYeDAgdi4cWNTdB0REbVQPMPSAgUGBkqfPTw8AACFhYXo0qULjh49ih9//NEgCamqqsLdu3dRXl6Os2fPwtvb22CsSd++fautY9u2bfjPf/6DCxcuoLS0FJWVlXBycjIpzo4dO2LEiBHYsmULBg0ahIsXLyIrK0tKno4fP46qqip0725495RGo0GHDh1MWhcREbVuTFhaIJVKJX3WX2rR6XQAgNLSUiQmJuKpp56qtpxara5X+1lZWXjppZeQkJCAkSNHwtnZGVu3bsWqVatMjnXKlCn45z//iffeew8pKSno3bs3evfuLcWqVCqRnZ1d7XH6fIkhERHdz6xLQmvWrIGPjw/UajXCwsJw6NChGutu27YNoaGhaNeuHezt7REcHIxPPvnEoM60adOkwZv6n5EjR5oTWpvXp08fnD17Ft26dav2o1Ao8NBDD+Hy5csoKCiQlvnpp58M2sjKyoK3tzfmz5+P0NBQBAQE4LfffjOoY21tjaqqqjrjGTduHO7evYu0tDSkpKRgypQp0rxHH30UVVVVKCwsrBZrS7ozioiImp7JZ1hSU1MRGxuLdevWISwsDMnJyYiMjMTZs2fRqVOnavVdXFzwxhtvoEePHrC2tsbOnTsRHR2NTp06ITIyUqo3cuRIg3ELNjY2Zm5S27Zw4UKMGTMGXbp0wTPPPAOFQoGjR4/ixIkTWLJkCSIiIuDv74+oqCisWLECt2/fxptvvgngz7M13bp1w5UrV7B161aEhYXh66+/xvbt2w3W4+Pjg4sXLyInJwedO3eGo6Oj0e/M3t4e48ePx4IFC3D69GlMnjxZmte9e3dMmTIFU6dOxapVq/Doo4+iqKgIGRkZCAwMxOjRo5uwp4iIqCUxOWFZvXo1pk+fjujoaADAunXr8PXXX2PDhg2YN29etfpDhw41mJ41axY2b96MH374wSBhsbGxafa/qh988qxOp0NJSQmcnJxa7MsSIyMjsXPnTixatAjLly+HSqVCjx498OKLLwIAlEolduzYgRdffBGPPfYY/Pz88NZbb2Hs2LHSJaMnn3wSr7zyCv75z39Co9Fg9OjRWLBgARISEqT1PP3009i2bRuGDRuG4uJibNy4EdOmTTMa05QpU/CXv/wFgwcPRpcuXQzmbdy4EUuWLMG//vUvXL16Fa6urujXrx/GjBnTJP1DREQtk0kJS0VFBbKzsxEXFyeVKRQKhIeHIysrq87lRVHEt99+i7Nnz2L58uUG8zIzM9GpUye0b98ew4cPx5IlS2oceKnRaAweIlZSUgIA0Gq10rNC9LRaLURRhE6nk8Z51Baf/t+66jaX+2MZPHiwdBlGXxYYGFitLCIiAhEREdXa0s/v3r279BwUANIt0X5+ftDpdBBFEYsWLcLbb79t8IqCf/7zn1IbKpVKumvo/vYfjAW4l0QZKwfuJVDx8fGIj4+vMV59P9T2vejj1mq11cbDmEO/Hz24P7V2gk6o9vn+Mkux1PfQVveDB7Ef2AdA0/SBKW0Jov4oXQ/Xrl2Dl5cX9u/fj/79+0vlc+fOxXfffYeDBw8aXe7WrVvw8vKCRqOBUqnE2rVr8fzzz0vzt27dCjs7O/j6+uLChQuYP38+HBwckJWVZfTgk5CQgMTExGrlKSkpsLOzMyjT33rr7e0Na2vr+m6qLIwZMwaHDh2CtbU1du/ejUceeaRR2t25cyfs7e3h7++PX3/9FXFxcXB2dkZaWlqjtN+Y9u/fjwkTJkCj0Uh3HBlTUVGBy5cvIz8/3+AheUREJF/l5eV47rnncOvWrTrvRG2Wu4QcHR2Rk5OD0tJSZGRkIDY2Fn5+ftLlokmTJkl1e/fujcDAQPj7+yMzMxNPPPFEtfbi4uIQGxsrTZeUlMDb2xsjRoyotsF3797F5cuX4eDgUOddMqIo4vbt23B0dGySlx+a6tNPP8WdO3cAAF26dGm0hKuyshKvv/46Ll26BFdXVzzxxBNYuXKl1Hdy6ochQ4bg8OHDAO7dOVTTDn337l3Y2tpi8ODB9b4bqjZarRbp6emIiIgwuCurtVt/fL30WdAJ8LjqgTyvPIs/OO7F3i9aZL1tdT94EPuBfQA0TR/or5DUh0kJi6urK5RKpcEdJgBQUFBQ6/gThUKBbt26AQCCg4Nx+vRpJCUlVRvfoufn5wdXV1f88ssvRhMWGxsbowM8VSpVtU6sqqqCIAhQKBR1jkvRX27Q17c0b2/vJml32rRpNY43AeTVD/b29tWe02KMQqGAIAhG94GGaOz25M5YYiIqRIsnLJb+DtraflAT9gP7AGjcPjClHZOORtbW1ggJCUFGRoZUptPpkJGRYXCJqC46na7WF9lduXIFN27ckB6KRkRERG2byZeEYmNjERUVhdDQUPTt2xfJyckoKyuT7hqaOnUqvLy8kJSUBABISkpCaGgo/P39odFosGvXLnzyySf44IMPAPz5oLOnn34a7u7uuHDhAubOnYtu3boZ3EVEREREbZfJCcvEiRNRVFSEhQsXIj8/H8HBwUhLS4ObmxsA4NKlSwaXEcrKyjBjxgxcuXIFtra26NGjB/773/9i4sSJAO7dJXLs2DFs3rwZxcXF8PT0xIgRI7B48WI+i4WIiIgAmDnoNiYmBjExMUbnZWZmGkwvWbIES5YsqbEtW1tb7N6925wwiIiIqI2w/MhSIiIiojowYSEiIiLZa9tva96bZDApiCLUGg0EGxugKZ4/Miyu7jr3GTp0KL777jsAwJEjRxAcHNz4MTWDTZs2Yfbs2SguLpam9YO0Z82aheTkZMsFR0RELQLPsMjc9OnTkZeXh169ejXbOjMzM9G+fXspwWhsEydORF5enkm3whMRUdvWts+wtAB2dnbN/lLI+qqoqDDr6bu2trawtbVtca9KICIiy+EZlhYkMzMTgiAgIyMDoaGhsLOzw4ABA3D27FmDel999RX69OkDtVoNPz8/JCYmSu/Xyc3NhSAIyMnJkeoXFxdDEARkZmYiNzdXerpwhw4dIAiC9FTcoUOHIiYmBrNnz4arq6v0nJzVq1ejd+/esLe3h7e3N2bMmIHS0tKm7xAiImozmLC0QG+88QZWrVqFn3/+GVZWVgYvkty3bx+mTp2KWbNm4dSpU/jwww+xadMmLF26tF5te3t74/PPPwcAnD59Gnl5eXjnnXek+Zs3b4a1tTV+/PFHrFu3DsC9x+K/++67OHnyJDZv3oxvv/0Wc+fObcQtJiKito6XhFqgpUuXYsiQIQCAefPmYfTo0bh79y7UajUSExMxb948REVFAbj3XqbFixdj7ty5iI+Pr7NtpVIJFxcXAECnTp2kz3oBAQFYsWKFQdns2bOlzz4+PliyZAlefvllrF27tiGbSUREJGHC0gIFBgZKn/XvWyosLESXLl1w9OhR/PjjjwZnVKqqqnD37l2Ul5c3eN0hISHVyr755hskJSXhzJkzKCkpQWVlpbQ+Ozu7Bq+TiIiICUsLdP/bLYU/br/Wv2FZ/26mp556qtpyarVaem2CKP759l2tVlvvddvb2xtM5+bmYsyYMXjllVewdOlSuLi44IcffsALL7yAiooKJixERNQomLC0Mn369MHZs2fRrVs3o/M7duwIAMjLy8Ojjz4KAAYDcAFId+9UVVXVub7s7GzodDqsWrVKSoY+++wzc8MnIiIyiglLK7Nw4UKMGTMGXbp0wTPPPAOFQoGjR4/ixIkTWLJkCWxtbdGvXz8sW7YMvr6+KCwsxJtvvmnQRteuXSEIAnbu3IkxY8bA1tYWDg4ORtfXrVs3aLVavPfeexg7dqzBYFwiIqLG0rYTlgeePCvqdLhbUgJrJycIipZ5A1VkZCR27tyJRYsWYfny5VCpVOjRowdefPFFqc6GDRvwwgsvICQkBA899BBWrFiBESNGSPO9vLwQFxeH+fPn44UXXsDUqVOxadMmo+sLCgrC6tWrsXz5csTFxWHw4MFISkrC1KlTm3pTiYioDWnbCUsLM3ToUIOxJwAQHBxcrSwyMlJ6RooxDz/8MPbv329Q9mAbc+bMweLFi6XLPED1N3Hrvfrqq3j11VcNyv72t79Jn6dNmyY9y4WIiMgcLfM0Qhuydu1aODg44Pjx45YOpdFs2bIFDg4O2Ldvn6VDISKiFoJnWGRsy5YtuHPnDgCgS5cuFo6m8Tz55JMICwsDALRr186ywRARUYvAhEXGvLy8LB1Ck3B0dISjo6OlwyAiohaEl4SIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPdwkREdXH3iRLR2DcA0/spnoy9fsUFQB6APtWA4KuSUICwO+zFm06YVmbs9ZgWhRFaDQa2NjYSG9BbkwzgmeYVH/o0KH47rvvAABHjhxBcHBwo8dkbJ1BQUFITEyssc6mTZswe/ZsFBcXN9p6p02bhs2bNwMAtm/fjvHjxzda20RE1PLxkpDMTZ8+HXl5eejVq1ezrG/btm1YtGiRNO3j44Pk5GSDOhMnTsS5c+cadb3vvPMO8vLyGrVNIiJqPdr0GZaWwM7ODu7u7s22PhcXF+h0OpSUlNRYx9bWFra2to26XmdnZzg7Ozdqm0RE1HrwDEsLkpmZCUEQ8PXXXyMwMBBqtRr9+vXDiRMnDOp9+eWXeOSRR2BjYwMfHx+sWrXKYP7atWsREBAAtVoNNzc3PPPMM9K8oUOHSi8yHD58OH777Te8+uqrEARBuky2adMm6ZH6586dgyAIOHPmjME63n77bfj7+0vTJ06cwKhRo+Dg4AA3Nzf87W9/w/Xr1xutb4iIqHVjwtICzZkzB6tWrcJPP/2Ejh07YuzYsdBqtQCA7OxsTJgwAZMmTcLx48eRkJCABQsWYNOmTQCAn3/+Gf/85z+xaNEinD17FmlpaRg8eLDR9XzxxRfo3LkzFi1ahLy8PKOXbLp3747Q0FBs2bLFoHzLli147rnnAADFxcUYPnw4Hn30Ufz8889IS0tDQUEBJkyY0Ii9QkRErRkvCbVA8fHxiIiIAABs3rwZnTt3xvbt2zFhwgSsXr0aTzzxBBYsWADgXkJx6tQpvPXWW5g2bRouXboEe3t7jBkzBo6OjujatSseffRRo+txcXGBUqmEo6NjrZelpkyZgvfffx+LFy8GcO+sS3Z2Nv773/8CAN5//308+uij+L//+z9pmQ0bNsDb2xvnzp1D9+7dG6VfiIio9eIZlhaof//+0mcXFxc89NBDOH36NADg9OnTePzxxw3qP/744zh//jyqqqoQERGBrl27ws/PD3/729+wZcsWlJeXNyieSZMmITc3FwcOHABw7+xKnz590KNHDwDA0aNHsXfvXjg4OEg/+nkXLlxo0LqJiKhtYMLSxjg6OuLw4cP49NNP4eHhgYULFyIoKKhBtyi7u7tj+PDhSElJAQCkpKRgypQp0vzS0lKMHTsWOTk5Bj/nz5+v8XIUERHR/ZiwtED6MxkA8Pvvv+PcuXN4+OGHAQAPP/wwfvzxR4P6P/74I7p37w6lUgkAsLKyQnh4OFasWIFjx44hNzcX3377rdF1WVtbo6qqqs6YpkyZgtTUVGRlZeHXX3/FpEmTpHl9+vTByZMn4ePjg27duhn82Nvbm7z9RETU9jBhaYEWLVqEjIwMnDhxAtOmTYOrq6v0oLV//etfyMjIwOLFi3Hu3Dls3rwZ77//Pl577TUAwM6dO/Huu+8iJycHv/32Gz7++GPodDo89NBDRtfl4+OD77//HlevXq31rp6nnnoKt2/fxiuvvIJhw4bB09NTmjdz5kzcvHkTkydPxk8//YQLFy5g9+7diI6OrlcyRERE1KYH3T745Fn980ecnJygUMg3l1u2bBlmzZqF8+fPIzg4GP/v//0/WFtbA7h3NuOzzz7DwoULsXjxYnh4eGDRokWYNm0aAKBdu3bYtm0bEhIScPfuXQQEBODTTz/FI488YnRdixYtwt///nf4+/tDo9FAFEWj9RwdHTF27Fh89tln2LBhg8E8T09P/Pjjj3j99dcxYsQIaDQadO3aFSNHjpR1PxMRkXy06YSlpRo4cGC1Z6/c7+mnn8bTTz9d47KZmZk1LpuZmWnw4Lh+/frh6NGjBnWmTZsmJUD3S01NRWpqqtF2AwICsG3bthrXS0REVBv+eStza9euhYODA44fP27pUJrUyy+/DAcHB0uHQUREMsUzLDK2ZcsW3LlzBwDQpUsX7N+/38IRNZ1FixZJ42w8PDwsHA0REckNExYZ8/LyMpgeOnRojWNIWrpOnTqhU6dOlg6DiIhkyqxLQmvWrIGPjw/UajXCwsJw6NChGutu27YNoaGhaNeuHezt7REcHIxPPvnEoI4oili4cCE8PDxga2uL8PBwnD9/3pzQiIiIqBUyOWFJTU1FbGws4uPjcfjwYQQFBSEyMhKFhYVG67u4uOCNN95AVlYWjh07hujoaERHR2P37t1SnRUrVuDdd9/FunXrcPDgQdjb2yMyMhJ37941f8se0FrPTNA9/H6JiFo3kxOW1atXY/r06YiOjkbPnj2xbt062NnZVbuVVW/o0KH461//iocffhj+/v6YNWsWAgMD8cMPPwC4d6BJTk7Gm2++iXHjxiEwMBAff/wxrl27hh07djRo4wBApVIBQIMfP0/yVlFRAQDSw/GIiKh1MWkMS0VFBbKzsxEXFyeVKRQKhIeHIysrq87lRVHEt99+i7Nnz2L58uUAgIsXLyI/Px/h4eFSPWdnZ4SFhSErK8vgial6Go0GGo1GmtbfgqvVaqW3Ft/P0dERBQUF0Ol0sLOzgyAINcZXUVGBO3fu1FinLWhp/aDT6VBYWAi1Wg1RFI3uA6bSt9EYbbUkgk6o9vn+Mkux1PdgsB+IMr2pshn6plX+Ppj4fWr/qK9t6v1Axn3cFPuBKW2ZlLBcv34dVVVVcHNzMyh3c3PDmTNnalzu1q1b8PLygkajgVKpxNq1a6W3Defn50ttPNimft6DkpKSkJiYWK18z549sLOzM7qMo6MjysrK+KCyVkqr1aKoqAjHjh1r1HbT09MbtT2584RntTKPq5a/a2vX5V0WXf+9/aCHRWOo0a7m65vW9ftg3veZXtrEb5dvxu/TXI25H5hy9aNZ7hJydHRETk4OSktLkZGRgdjYWPj5+WHo0KFmtRcXF4fY2FhpuqSkBN7e3hgxYgScnJxqXK6qqgqVlZU1jneorKzE/v37MWDAAFhZtd0bqFpaPwiCAJVK1ajJqFarRXp6OiIiIqTLim3B+uPrpc+CToDHVQ/keeVBVFh2jNCLvV+0yHoN9oMD71kkhjoNiq27TgO1yt+HfatNqq4VFUgv7Y4Ih3NQCbomCgrN8n2aqyn2A/0Vkvow6Wjk6uoKpVKJgoICg/KCggK4u7vXuJxCoUC3bt0AAMHBwTh9+jSSkpIwdOhQabmCggKD528UFBQgODjYaHs2NjawsbGpVq5SqWrtxLo6WKvVorKyEg4ODq3nl9IM7Ic/1bVPtTbGEhNRIVo8YbH0d6BSqZr2INUQzdg3rer3wczvUyXomnZfaAH925j7gSntmPQnqbW1NUJCQpCRkSGV6XQ6ZGRkoH///vVuR6fTSWNQfH194e7ubtBmSUkJDh48aFKbRERE1HqZfL4/NjYWUVFRCA0NRd++fZGcnIyysjJER0cDAKZOnQovLy8kJSUBuDfeJDQ0VHp53q5du/DJJ5/ggw8+AHDvdP7s2bOxZMkSBAQEwNfXFwsWLICnp6f0BmIiIiJq20xOWCZOnIiioiIsXLgQ+fn5CA4ORlpamjRo9tKlSwZjCcrKyjBjxgxcuXIFtra26NGjB/773/9i4sSJUp25c+eirKwML730EoqLizFw4ECkpaVBrVY3wiYSERFRS2fWiMqYmBjExMQYnffgm4CXLFmCJUuW1NqeIAhYtGgRFi1aZE44RERE1MrxHl8iIiKSPSYsREREJHtMWIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JCxEREckeExYiIiKSPSYsREREJHtMWIiIiEj2zHqXEFFr8Xb6OaPlglgFXwBr9v4CUVA2a0yvRnRv1vUREbUEPMNCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0rSwdARIbeTj9nsXUfLrkhfbaCAuOtPC0WCxHR/XiGhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9sxKWNasWQMfHx+o1WqEhYXh0KFDNdb96KOPMGjQILRv3x7t27dHeHh4tfrTpk2DIAgGPyNHjjQnNCIiImqFTE5YUlNTERsbi/j4eBw+fBhBQUGIjIxEYWGh0fqZmZmYPHky9u7di6ysLHh7e2PEiBG4evWqQb2RI0ciLy9P+vn000/N2yIiIiJqdUxOWFavXo3p06cjOjoaPXv2xLp162BnZ4cNGzYYrb9lyxbMmDEDwcHB6NGjB9avXw+dToeMjAyDejY2NnB3d5d+2rdvb94WERERUatj0oPjKioqkJ2djbi4OKlMoVAgPDwcWVlZ9WqjvLwcWq0WLi4uBuWZmZno1KkT2rdvj+HDh2PJkiXo0KGD0TY0Gg00Go00XVJSAgDQarXQarWmbJIB/bINaaM1aEv9IIhVtZbXNL+1srrvbxj9Z0EnWCociaX2RYPfBVGmQ/6aoW9a5f8JJn6f2j/qa5t6P5BxHzfFfmBKW4IoimJ9K1+7dg1eXl7Yv38/+vfvL5XPnTsX3333HQ4ePFhnGzNmzMDu3btx8uRJqNVqAMDWrVthZ2cHX19fXLhwAfPnz4eDgwOysrKgVCqrtZGQkIDExMRq5SkpKbCzs6vv5hAREZEFlZeX47nnnsOtW7fg5ORUa91mfTT/smXLsHXrVmRmZkrJCgBMmjRJ+ty7d28EBgbC398fmZmZeOKJJ6q1ExcXh9jYWGm6pKREGhtT1wbXRqvVIj09HREREVCpVGa309K1pX5Ys/cXo+WCWAWfuxeQq/aHKFRPmluro7e3SZ+toMAYq57I88qDqKj33zVN4sXeL1pkvQa/Cwfes0gMdRoUW3edBmqV/yfsW21Sda2oQHppd0Q4nINK0DVRUPKm7fePRt8P9FdI6sOkhMXV1RVKpRIFBQUG5QUFBXB3d6912ZUrV2LZsmX45ptvEBgYWGtdPz8/uLq64pdffjGasNjY2MDGxqZauUqlapRObKx2Wrq20A91JSOioGxTCUslqv9HLCpEiycslt4PVSqVfA9Szdg3rer/BDO/T5Wgk+++0NT++O4bcz8wpR2TLsZZW1sjJCTEYMCsfgDt/ZeIHrRixQosXrwYaWlpCA0NrXM9V65cwY0bN+Dh4WFKeERERNRKmTx6KDY2Fh999BE2b96M06dP45VXXkFZWRmio6MBAFOnTjUYlLt8+XIsWLAAGzZsgI+PD/Lz85Gfn4/S0lIAQGlpKebMmYMDBw4gNzcXGRkZGDduHLp164bIyMhG2kwiIiJqyUwewzJx4kQUFRVh4cKFyM/PR3BwMNLS0uDm5gYAuHTpEhSKP/OgDz74ABUVFXjmmWcM2omPj0dCQgKUSiWOHTuGzZs3o7i4GJ6enhgxYgQWL15s9LIPERERtT1mDbqNiYlBTEyM0XmZmZkG07m5ubW2ZWtri927d5sTBhEREbURMn2wABEREdGfmLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JCxEREckeExYiIiKSPSYsREREJHtMWIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JCxEREckeExYiIiKSPSYsREREJHtMWIiIiEj2mLAQERGR7DFhISIiItmzsnQA1LqszVlr6RCqmRE8w9IhUCu0tviYpUO4pxl+5wSdAE94Yv3x9RAVosnL83eQGgPPsBAREZHsMWEhIiIi2WPCQkRERLLHhIWIiIhkjwkLERERyR4TFiIiIpI93tZMzSrrwo1mX6em6Fyzr7M1OXTxJiqhs2gM93+Hr0Z0t2AkRGQpPMNCREREsseEhYiIiGSPl4SIiKhpXNx379/fb1k2DmoVeIaFiIiIZI8JCxEREckeExYiIiKSPSYsREREJHtMWIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZM+shGXNmjXw8fGBWq1GWFgYDh06VGPdjz76CIMGDUL79u3Rvn17hIeHV6sviiIWLlwIDw8P2NraIjw8HOfPnzcnNCIiImqFTE5YUlNTERsbi/j4eBw+fBhBQUGIjIxEYWGh0fqZmZmYPHky9u7di6ysLHh7e2PEiBG4evWqVGfFihV49913sW7dOhw8eBD29vaIjIzE3bt3zd8yIiIiajVMfpfQ6tWrMX36dERHRwMA1q1bh6+//hobNmzAvHnzqtXfsmWLwfT69evx5ZdfIiMjA1OnToUoikhOTsabb76JcePGAQA+/vhjuLm5YceOHZg0aVK1NjUaDTQajTRdUlICANBqtdBqtaZukkS/bEPaaA0a0g+CTqh1vpUFrkIKYpXZy5izbEt2//ej/2yJ7+xB938Pzfn7afC7IBr2gyCXV7HV8TvXGPS/13X9fld3r4+0ouX3oYbSb0Nr2BZzNcUx0pS2BFEUxfpWrqiogJ2dHb744guMHz9eKo+KikJxcTG++uqrOtu4ffs2OnXqhM8//xxjxozBr7/+Cn9/fxw5cgTBwcFSvSFDhiA4OBjvvPNOtTYSEhKQmJhYrTwlJQV2dnb13RwiIiKyoPLycjz33HO4desWnJycaq1r0p8I169fR1VVFdzc3AzK3dzccObMmXq18frrr8PT0xPh4eEAgPz8fKmNB9vUz3tQXFwcYmNjpemSkhLpUlNdG1wbrVaL9PR0REREQKVSmd1OS9eQflh/fH2t8w9dvNmQ0MwS5PiUycsIYhV87l5ArtofoqBsgqjk6ejtbdJnKygwxqondlaeQiV0FozK8DucOaxbs63X4HfhwHsG89bfOtFscdSq64AmX4WgE+Bx1QN5XnkQFfX+Gxf4bT8A4EXnXk0UWfPRigqkl3ZHhMM5qATL/j5YirbfPxr9GKm/QlIfzXpOc9myZdi6dSsyMzOhVqvNbsfGxgY2NjbVylUqVaN0YmO109KZ0w91/WdmiQNfQxIOUVC2qYTF2PdTCZ3FE5b7vwNL/G6qVKpqBykRlc0eh1GmJBANJCpE0xKWP/qoNR3gVYKuVW2PSf743WvMY6Qp7Zh0Mc7V1RVKpRIFBQUG5QUFBXB3d6912ZUrV2LZsmXYs2cPAgMDpXL9cua0SURERG2DSQmLtbU1QkJCkJGRIZXpdDpkZGSgf//+NS63YsUKLF68GGlpaQgNDTWY5+vrC3d3d4M2S0pKcPDgwVrbJCIiorbD5EtCsbGxiIqKQmhoKPr27Yvk5GSUlZVJdw1NnToVXl5eSEpKAgAsX74cCxcuREpKCnx8fKRxKQ4ODnBwcIAgCJg9ezaWLFmCgIAA+Pr6YsGCBfD09DQY2EtERERtl8kJy8SJE1FUVISFCxciPz8fwcHBSEtLkwbNXrp0CQrFnyduPvjgA1RUVOCZZ54xaCc+Ph4JCQkAgLlz56KsrAwvvfQSiouLMXDgQKSlpTVonAsRERG1HmYNuo2JiUFMTIzReZmZmQbTubm5dbYnCAIWLVqERYsWmRMOERERtXJt9wk4RERE1GIwYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJnllvayYiak6HS1Klz2tzOjTbegWdAE94Yv3x9RCLjzXbeomoOp5hISIiItljwkJERESyx4SFiIiIZI8JCxEREckeExYiIiKSPd4lRETUAJeL71h0/Vcu3GjydVhBgfFWnjh08SYqoav3cp1L/uibdk0TF7UtPMNCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGTPrIRlzZo18PHxgVqtRlhYGA4dOlRj3ZMnT+Lpp5+Gj48PBEFAcnJytToJCQkQBMHgp0ePHuaERkRERK2QyQlLamoqYmNjER8fj8OHDyMoKAiRkZEoLCw0Wr+8vBx+fn5YtmwZ3N3da2z3kUceQV5envTzww8/mBoaERERtVJWpi6wevVqTJ8+HdHR0QCAdevW4euvv8aGDRswb968avUfe+wxPPbYYwBgdL4UiJVVrQlNS/B2+jlLh1DNqxHdLR0CERFRg5mUsFRUVCA7OxtxcXFSmUKhQHh4OLKyshoUyPnz5+Hp6Qm1Wo3+/fsjKSkJXbp0MVpXo9FAo9FI0yUlJQAArVYLrVZrdgz6Zc1tQxCrzF53UzFnWxrSD4JOqHW+lQWGTZnzveiXkeN32pTu/370ny3xndWmrn2sKdZ171/j/10qoWq2eIxpju/H3H1B3zdaUV77kDn029AatsVcDT1G1tZmfQiiKIr1rXzt2jV4eXlh//796N+/v1Q+d+5cfPfddzh48GCty/v4+GD27NmYPXu2Qfn//vc/lJaW4qGHHkJeXh4SExNx9epVnDhxAo6OjtXaSUhIQGJiYrXylJQU2NnZ1XdziIiIyILKy8vx3HPP4datW3Bycqq1rsmXhJrCqFGjpM+BgYEICwtD165d8dlnn+GFF16oVj8uLg6xsbHSdElJCby9vTFixIg6N7g2Wq0W6enpiIiIgEpl+l9Na/b+Yva6m8rMYd1MXqYh/bD++Ppa5x+6eNPkeBoqyPEpk5cRxCr43L2AXLU/REHZBFHJ09Hb26TPVlBgjFVP7Kw8hUroLBiVob6+Ls22LkEnwOOqB/K88iBe/tFonavFd5stHqPrdwpu8nWYuy94leQAAOK7hjZRZM1HKyqQXtodEQ7noBLk8/vQnLT9/tGgY6Qx+isk9WFSwuLq6gqlUomCggKD8oKCgkYdf9KuXTt0794dv/xiPAGwsbGBjY1NtXKVStUonWhuO3I8sDWkP8zpB1FR+wk7Sxz4GvK9iIJSlt9rUzH2/VRCJ6uEpa59rKnWKaLS6LwqNN7pcXM053dj6r6g75vWdIBXCbpWtT0m+eN40FjHWn1b9WXSxThra2uEhIQgIyNDKtPpdMjIyDC4RNRQpaWluHDhAjw8PBqtTSIiImq5TL4kFBsbi6ioKISGhqJv375ITk5GWVmZdNfQ1KlT4eXlhaSkJAD3BuqeOnVK+nz16lXk5OTAwcEB3brdu1zx2muvYezYsejatSuuXbuG+Ph4KJVKTJ48ubG2k4iIiFowkxOWiRMnoqioCAsXLkR+fj6Cg4ORlpYGNzc3AMClS5egUPx54ubatWt49NFHpemVK1di5cqVGDJkCDIzMwEAV65cweTJk3Hjxg107NgRAwcOxIEDB9CxY8cGbh4RERG1BmYNuo2JiUFMTIzRefokRM/Hxwd13Yi0detWc8IgIiKiNkIWdwkRmatzSXaddfoV3zK5XZ2gxPUOA/DYlU1QtMJnsRzo8pKlQyAiMknbfQIOERERtRhMWIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JCxEREcken8NCrd7/pzD9LdpKqPAYBuB/wq+oEprm5XZP6kx/kzYRUVvFMyxEREQke0xYiIiISPaYsBAREZHscQwLERE1qaxfb1g6BAP9/TpYOgQyA8+wEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkexZWToAorbq/1P8YrF1XylJtdi6qe2x5L5uzJHiq5jRLtDSYZCJeIaFiIiIZI8JCxEREckeExYiIiKSPY5hIaIWJevCjWZblxUUGG/liUMXb8K95E6zrZeIquMZFiIiIpI9JixEREQke0xYiIiISPbMSljWrFkDHx8fqNVqhIWF4dChQzXWPXnyJJ5++mn4+PhAEAQkJyc3uE0iIiJqW0xOWFJTUxEbG4v4+HgcPnwYQUFBiIyMRGFhodH65eXl8PPzw7Jly+Du7t4obRIREVHbYnLCsnr1akyfPh3R0dHo2bMn1q1bBzs7O2zYsMFo/cceewxvvfUWJk2aBBsbm0Zpk4iIiNoWk25rrqioQHZ2NuLi4qQyhUKB8PBwZGVlmRWAOW1qNBpoNBppuqSkBACg1Wqh1WrNikO//P3/mkoQq8xed1MxZ1sa0g+CTqh1vlUjD5tSQtWo7ekp/mhX0UTtW1p9vgd9ncb+zlqS+/ugqfa1hmqO78fcfUGufSbAClrRtG3R1zd1udakocfI2tqsD5MSluvXr6Oqqgpubm4G5W5ubjhz5owpTTWozaSkJCQmJlYr37NnD+zs7MyK437p6elmLefb4DU3vl27zpm9rDn94AnPWuePt6p9vslcejVuew8IcZnYpO1bymMm1B1j1bPJ4mgpxlj1BFzk2Q+mfJcNZfK+0MS/nw2x67Z5y6WXdm/cQFqSP44J5h4jjSkvL6933Rb54Li4uDjExsZK0yUlJfD29saIESPg5ORkdrtarRbp6emIiIiASmX6XwZr9srrBV8AMHNYN5OXaUg/rD++vtb5hy7eNDme2niV5DRqe3oKqBDiMhHZN1OhQ+P9NSEXV52C66xjBQXGWPXEzspTqISu6YOSofv7wK3ksKXDMao+32VDmbsvNNXvZ0N5tVPjRWfTkimtqEB6aXdEOJyDSmibvw/afv9o0DHSGP0VkvowKWFxdXWFUqlEQUGBQXlBQUGNA2qbok0bGxuj42FUKlWjdKK57YiCssHrbmwN6Q9z+kFUiLXOb+wDX1UTJxM6aJt8HZZgyvdQCV2bTVj0KqGT7X7QnN+NqfuCXPtMhJXZSYdK0LXZhAV/HA8a61irb6u+TLoYZ21tjZCQEGRkZEhlOp0OGRkZ6N+/vylNNWmbRERE1LqYfEkoNjYWUVFRCA0NRd++fZGcnIyysjJER0cDAKZOnQovLy8kJSUBuDeo9tSpU9Lnq1evIicnBw4ODujWrVu92iQiIqK2zeSEZeLEiSgqKsLChQuRn5+P4OBgpKWlSYNmL126BIXizxM3165dw6OPPipNr1y5EitXrsSQIUOQmZlZrzaJiIiobTNr0G1MTAxiYmKMztMnIXo+Pj4QxdrHNdTVJhEREbVtbfeGciIiImoxmLAQERGR7DFhISIiItlrkQ+Oa25rc9bWq97hkhtNHInp1uZ0MHkZQSfAE55Yf3x9nc9VISIiag5MWIhINjqXZFs6BANKqACXXvAqyYH83hR2T3P0mWE/yPNhcNT68ZIQERERyR7PsLRyWRdMv0xlBQXGW3ni0MWbbf6R7EREJA88w0JERESyx4SFiIiIZI8JCxEREckeExYiIiKSPSYsREREJHtMWIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JCxEREckeExYiIiKSPSYsREREJHtMWIiIiEj2rCwdABE1v84l2XXWUUIFuPSCV0kOqqBthqiIiGrGMyxEREQke0xYiIiISPaYsBAREZHsMWEhIiIi2WPCQkRERLLHhIWIiIhkjwkLERERyR4TFiIiIpI9JixEREQke0xYiIiISPaYsBAREZHsMWEhIiIi2WPCQkRERLLHhIWIiIhkjwkLERERyR4TFiIiIpI9sxKWNWvWwMfHB2q1GmFhYTh06FCt9T///HP06NEDarUavXv3xq5duwzmT5s2DYIgGPyMHDnSnNCIiIioFbIydYHU1FTExsZi3bp1CAsLQ3JyMiIjI3H27Fl06tSpWv39+/dj8uTJSEpKwpgxY5CSkoLx48fj8OHD6NWrl1Rv5MiR2LhxozRtY2Nj5iZRU+hckm3pEIiIqA0z+QzL6tWrMX36dERHR6Nnz55Yt24d7OzssGHDBqP133nnHYwcORJz5szBww8/jMWLF6NPnz54//33DerZ2NjA3d1d+mnfvr15W0REREStjklnWCoqKpCdnY24uDipTKFQIDw8HFlZWUaXycrKQmxsrEFZZGQkduzYYVCWmZmJTp06oX379hg+fDiWLFmCDh06GG1To9FAo9FI0yUlJQAArVYLrVZryiYZ0C/7YBuCTqjX8latZEiQfjvu3x4lVJYKxyIUf2yvoo1t9/3YB+wDvdbWDwKsoBVN+/9aX9/U5VqTmo6RjdFmfZiUsFy/fh1VVVVwc3MzKHdzc8OZM2eMLpOfn2+0fn5+vjQ9cuRIPPXUU/D19cWFCxcwf/58jBo1CllZWVAqldXaTEpKQmJiYrXyPXv2wM7OzpRNMio9Pd1g2hOe9VpuvFX96rUUY6x6/jnh0qvmiq1YiMtES4dgcewD9oFea+qHXbfNWy69tHvjBtKS/HFsfPAY2RDl5eX1rmvyGJamMGnSJOlz7969ERgYCH9/f2RmZuKJJ56oVj8uLs7grE1JSQm8vb0xYsQIODk5mR2HVqtFeno6IiIioFL9+ZfE+uPr67X8oYs3zV63nFhBgTFWPbGz8hQqoQMAeJXkWDaoZqaACiEuE5F9MxU6NN5fEy0J+4B9oNfa+sGrnRovOpv2R5hWVCC9tDsiHM5BJeiaKDJ50/b7h9FjZEPor5DUh0kJi6urK5RKJQoKCgzKCwoK4O7ubnQZd3d3k+oDgJ+fH1xdXfHLL78YTVhsbGyMDspVqVSN0okPtiMqxHotpz+4txaV0EnbVNUK/pMyhw7aNrvteuwD9oFea+kHEVZmJx0qQddmExb8cVxsrGOtvq36MulinLW1NUJCQpCRkSGV6XQ6ZGRkoH///kaX6d+/v0F94N7ppJrqA8CVK1dw48YNeHh4mBIeERERtVImjx6KjY3FRx99hM2bN+P06dN45ZVXUFZWhujoaADA1KlTDQblzpo1C2lpaVi1ahXOnDmDhIQE/Pzzz4iJiQEAlJaWYs6cOThw4AByc3ORkZGBcePGoVu3boiMjGykzSQiIqKWzOQxLBMnTkRRUREWLlyI/Px8BAcHIy0tTRpYe+nSJSgUf+ZBAwYMQEpKCt58803Mnz8fAQEB2LFjh/QMFqVSiWPHjmHz5s0oLi6Gp6cnRowYgcWLF/NZLERERATAzEG3MTEx0hmSB2VmZlYre/bZZ/Hss88arW9ra4vdu3ebEwYRERG1EW33hnIiIiJqMZiwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJnlkJy5o1a+Dj4wO1Wo2wsDAcOnSo1vqff/45evToAbVajd69e2PXrl0G80VRxMKFC+Hh4QFbW1uEh4fj/Pnz5oRGRERErZDJCUtqaipiY2MRHx+Pw4cPIygoCJGRkSgsLDRaf//+/Zg8eTJeeOEFHDlyBOPHj8f48eNx4sQJqc6KFSvw7rvvYt26dTh48CDs7e0RGRmJu3fvmr9lRERE1GqYnLCsXr0a06dPR3R0NHr27Il169bBzs4OGzZsMFr/nXfewciRIzFnzhw8/PDDWLx4Mfr06YP3338fwL2zK8nJyXjzzTcxbtw4BAYG4uOPP8a1a9ewY8eOBm0cERERtQ5WplSuqKhAdnY24uLipDKFQoHw8HBkZWUZXSYrKwuxsbEGZZGRkVIycvHiReTn5yM8PFya7+zsjLCwMGRlZWHSpEnV2tRoNNBoNNL0rVu3AAA3b96EVqs1ZZMMaLValJeX48aNG1CpVFL53ZL6nenRlVeYvW450UGBcqty6CoroIMOAFB5x8JBNTMdgPLycmjv4I8eaHvYB+wDvdbWD3etdbhhZdr/11pRce/4IFRAJbSGXjCd9sYNo8fIhrh9+zaAeycv6mJSwnL9+nVUVVXBzc3NoNzNzQ1nzpwxukx+fr7R+vn5+dJ8fVlNdR6UlJSExMTEauW+vr712xCq08eWDkAWvrB0ADLAPmAf6LWufviXpQNokRKarOXbt2/D2dm51jomJSxyERcXZ3DWRqfT4ebNm+jQoQMEQTC73ZKSEnh7e+Py5ctwcnJqjFBbJPYD+wBgHwDsAz32A/sAaJo+EEURt2/fhqenZ511TUpYXF1doVQqUVBQYFBeUFAAd3d3o8u4u7vXWl//b0FBATw8PAzqBAcHG23TxsYGNjY2BmXt2rUzZVNq5eTk1GZ3yPuxH9gHAPsAYB/osR/YB0Dj90FdZ1b0TBp0a21tjZCQEGRkZEhlOp0OGRkZ6N+/v9Fl+vfvb1AfANLT06X6vr6+cHd3N6hTUlKCgwcP1tgmERERtS0mXxKKjY1FVFQUQkND0bdvXyQnJ6OsrAzR0dEAgKlTp8LLywtJSUkAgFmzZmHIkCFYtWoVRo8eja1bt+Lnn3/Gv//9bwCAIAiYPXs2lixZgoCAAPj6+mLBggXw9PTE+PHjG29LiYiIqMUyOWGZOHEiioqKsHDhQuTn5yM4OBhpaWnSoNlLly5BofjzxM2AAQOQkpKCN998E/Pnz0dAQAB27NiBXr16SXXmzp2LsrIyvPTSSyguLsbAgQORlpYGtVrdCJtYfzY2NoiPj692uamtYT+wDwD2AcA+0GM/sA8Ay/eBINbnXiIiIiIiC+K7hIiIiEj2mLAQERGR7DFhISIiItljwkJERESyx4SFiIiIZI8JSy2efPJJdOnSBWq1Gh4eHvjb3/6Ga9euWTqsZpObm4sXXngBvr6+sLW1hb+/P+Lj41FR0Tpe8lhfS5cuxYABA2BnZ9eoT1SWuzVr1sDHxwdqtRphYWE4dOiQpUNqVt9//z3Gjh0LT09PCILQ5t4en5SUhMceewyOjo7o1KkTxo8fj7Nnz1o6rGb3wQcfIDAwUHq6a//+/fG///3P0mFZ1LJly6RnqDUnJiy1GDZsGD777DOcPXsWX375JS5cuIBnnnnG0mE1mzNnzkCn0+HDDz/EyZMn8fbbb2PdunWYP3++pUNrVhUVFXj22WfxyiuvWDqUZpOamorY2FjEx8fj8OHDCAoKQmRkJAoLCy0dWrMpKytDUFAQ1qxZY+lQLOK7777DzJkzceDAAaSnp0Or1WLEiBEoKyuzdGjNqnPnzli2bBmys7Px888/Y/jw4Rg3bhxOnjxp6dAs4qeffsKHH36IwMDA5l+5SPX21VdfiYIgiBUVFZYOxWJWrFgh+vr6WjoMi9i4caPo7Oxs6TCaRd++fcWZM2dK01VVVaKnp6eYlJRkwagsB4C4fft2S4dhUYWFhSIA8bvvvrN0KBbXvn17cf369ZYOo9ndvn1bDAgIENPT08UhQ4aIs2bNatb18wxLPd28eRNbtmzBgAEDoFKpLB2Oxdy6dQsuLi6WDoOaUEVFBbKzsxEeHi6VKRQKhIeHIysry4KRkSXdunULANr0739VVRW2bt2KsrKyNvmuu5kzZ2L06NEG/zc0JyYsdXj99ddhb2+PDh064NKlS/jqq68sHZLF/PLLL3jvvffw97//3dKhUBO6fv06qqqqpNdt6Lm5uSE/P99CUZEl6XQ6zJ49G48//rjBa1XaiuPHj8PBwQE2NjZ4+eWXsX37dvTs2dPSYTWrrVu34vDhw9J7Ai2hzSUs8+bNgyAItf6cOXNGqj9nzhwcOXIEe/bsgVKpxNSpUyG28LcZmNoHAHD16lWMHDkSzz77LKZPn26hyBuPOX1A1FbNnDkTJ06cwNatWy0dikU89NBDyMnJwcGDB/HKK68gKioKp06dsnRYzeby5cuYNWsWtmzZ0uzv+Ltfm3uXUFFREW7cuFFrHT8/P1hbW1crv3LlCry9vbF///4WfTrQ1D64du0ahg4din79+mHTpk0GL7dsqczZDzZt2oTZs2ejuLi4iaOzrIqKCtjZ2eGLL74weGN6VFQUiouL2+RZRkEQsH379jb5BvmYmBh89dVX+P777+Hr62vpcGQhPDwc/v7++PDDDy0dSrPYsWMH/vrXv0KpVEplVVVVEAQBCoUCGo3GYF5TMfltzS1dx44d0bFjR7OW1el0AACNRtOYITU7U/rg6tWrGDZsGEJCQrBx48ZWkawADdsPWjtra2uEhIQgIyNDOkDrdDpkZGQgJibGssFRsxFFEf/4xz+wfft2ZGZmMlm5j06na/HHAVM88cQTOH78uEFZdHQ0evTogddff71ZkhWgDSYs9XXw4EH89NNPGDhwINq3b48LFy5gwYIF8Pf3b9FnV0xx9epVDB06FF27dsXKlStRVFQkzXN3d7dgZM3r0qVLuHnzJi5duoSqqirk5OQAALp16wYHBwfLBtdEYmNjERUVhdDQUPTt2xfJyckoKytDdHS0pUNrNqWlpfjll1+k6YsXLyInJwcuLi7o0qWLBSNrHjNnzkRKSgq++uorODo6SuOXnJ2dYWtra+Homk9cXBxGjRqFLl264Pbt20hJSUFmZiZ2795t6dCajaOjY7WxS/qxnc06pqlZ70lqQY4dOyYOGzZMdHFxEW1sbEQfHx/x5ZdfFq9cuWLp0JrNxo0bRQBGf9qSqKgoo32wd+9eS4fWpN577z2xS5cuorW1tdi3b1/xwIEDlg6pWe3du9fo9x4VFWXp0JpFTb/7GzdutHRozer5558Xu3btKlpbW4sdO3YUn3jiCXHPnj2WDsviLHFbc5sbw0JEREQtT+sYkEBEREStGhMWIiIikj0mLERERCR7TFiIiIhI9piwEBERkewxYSEiIiLZY8JCREREsseEhYiIiGSPCQsRERHJHhMWIiIikj0mLERERCR7/z9HhY5nYwKkDgAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "for ctrl_s in ctrl_str:\n", + " plt.hist(\n", + " [r for r, t in zip(logs[\"env/reward_dist\"], task_list) if t == ctrl_s],\n", + " density=True,\n", + " alpha=0.5,\n", + " label=ctrl_s,\n", + " )\n", + "plt.legend(loc=\"best\")\n", + "plt.title(\"reward distribution\")\n", + "plt.grid(True)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Save model\n", + "Finally, we save the model to disk for later usage." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gpt2_model.save_pretrained(\"gpt2-imdb-ctrl\")\n", + "gpt2_tokenizer.save_pretrained(\"gpt2-imdb-ctrl\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "trl", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.12" + }, + "vscode": { + "interpreter": { + "hash": "d2cfb53525227c89f8d14fa784301fa46c451cc9223d94ccce9e17956835eea2" + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/notebooks/gpt2-sentiment.ipynb b/examples/notebooks/gpt2-sentiment.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..a5b6edc82100732fdf34638a8f9df753a9fe4c35 --- /dev/null +++ b/examples/notebooks/gpt2-sentiment.ipynb @@ -0,0 +1,861 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Tune GPT2 to generate positive reviews\n", + "> Optimise GPT2 to produce positive IMDB movie reviews using a BERT sentiment classifier as a reward function." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "

Figure: Experiment setup to tune GPT2. The yellow arrows are outside the scope of this notebook, but the trained models are available through Hugging Face.

\n", + "
\n", + "\n", + "\n", + "In this notebook we fine-tune GPT2 (small) to generate positive movie reviews based on the IMDB dataset. The model gets the start of a real review and is tasked to produce positive continuations. To reward positive continuations we use a BERT classifier to analyse the sentiment of the produced sentences and use the classifier's outputs as rewards signals for PPO training." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup experiment" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Import dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install transformers trl wandb" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from tqdm import tqdm\n", + "import pandas as pd\n", + "\n", + "tqdm.pandas()\n", + "\n", + "from transformers import pipeline, AutoTokenizer\n", + "from datasets import load_dataset\n", + "\n", + "from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead\n", + "from trl.core import LengthSampler" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Configuration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "config = PPOConfig(\n", + " model_name=\"lvwerra/gpt2-imdb\",\n", + " learning_rate=1.41e-5,\n", + " log_with=\"wandb\",\n", + ")\n", + "\n", + "sent_kwargs = {\"top_k\": None, \"function_to_apply\": \"none\", \"batch_size\": 16}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import wandb\n", + "\n", + "wandb.init()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can see that we load a GPT2 model called `gpt2_imdb`. This model was additionally fine-tuned on the IMDB dataset for 1 epoch with the huggingface [script](https://github.com/huggingface/transformers/blob/main/examples/legacy/run_language_modeling.py) (no special settings). The other parameters are mostly taken from the original paper [\"Fine-Tuning Language Models from Human Preferences\"](\n", + "https://huggingface.co/papers/1909.08593). This model as well as the BERT model is available in the Huggingface model zoo [here](https://huggingface.co/models). The following code should automatically download the models." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load data and models" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load IMDB dataset\n", + "The IMDB dataset contains 50k movie review annotated with \"positive\"/\"negative\" feedback indicating the sentiment. We load the IMDB dataset into a DataFrame and filter for comments that are at least 200 characters. Then we tokenize each text and cut it to random size with the `LengthSampler`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def build_dataset(\n", + " config,\n", + " dataset_name=\"stanfordnlp/imdb\",\n", + " input_min_text_length=2,\n", + " input_max_text_length=8,\n", + "):\n", + " \"\"\"\n", + " Build dataset for training. This builds the dataset from `load_dataset`, one should\n", + " customize this function to train the model on its own dataset.\n", + "\n", + " Args:\n", + " dataset_name (`str`):\n", + " The name of the dataset to be loaded.\n", + "\n", + " Returns:\n", + " dataloader (`torch.utils.data.DataLoader`):\n", + " The dataloader for the dataset.\n", + " \"\"\"\n", + " tokenizer = AutoTokenizer.from_pretrained(config.model_name)\n", + " tokenizer.pad_token = tokenizer.eos_token\n", + " # load imdb with datasets\n", + " ds = load_dataset(dataset_name, split=\"train\")\n", + " ds = ds.rename_columns({\"text\": \"review\"})\n", + " ds = ds.filter(lambda x: len(x[\"review\"]) > 200, batched=False)\n", + "\n", + " input_size = LengthSampler(input_min_text_length, input_max_text_length)\n", + "\n", + " def tokenize(sample):\n", + " sample[\"input_ids\"] = tokenizer.encode(sample[\"review\"])[: input_size()]\n", + " sample[\"query\"] = tokenizer.decode(sample[\"input_ids\"])\n", + " return sample\n", + "\n", + " ds = ds.map(tokenize, batched=False)\n", + " ds.set_format(type=\"torch\")\n", + " return ds" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = build_dataset(config)\n", + "\n", + "\n", + "def collator(data):\n", + " return dict((key, [d[key] for d in data]) for key in data[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load pre-trained GPT2 language models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We load the GPT2 model with a value head and the tokenizer. We load the model twice; the first model is optimized while the second model serves as a reference to calculate the KL-divergence from the starting point. This serves as an additional reward signal in the PPO training to make sure the optimized model does not deviate too much from the original language model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)\n", + "ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)\n", + "tokenizer = AutoTokenizer.from_pretrained(config.model_name)\n", + "\n", + "tokenizer.pad_token = tokenizer.eos_token" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Initialize PPOTrainer\n", + "The `PPOTrainer` takes care of device placement and optimization later on:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ppo_trainer = PPOTrainer(\n", + " config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load BERT classifier\n", + "We load a BERT classifier fine-tuned on the IMDB dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "device = ppo_trainer.accelerator.device\n", + "if ppo_trainer.accelerator.num_processes == 1:\n", + " device = 0 if torch.cuda.is_available() else \"cpu\" # to avoid a `pipeline` bug\n", + "sentiment_pipe = pipeline(\n", + " \"sentiment-analysis\", model=\"lvwerra/distilbert-imdb\", device=device\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The model outputs are the logits for the negative and positive class. We will use the logits for positive class as a reward signal for the language model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'label': 'NEGATIVE', 'score': 2.335048198699951},\n", + " {'label': 'POSITIVE', 'score': -2.726576328277588}]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text = \"this movie was really bad!!\"\n", + "sentiment_pipe(text, **sent_kwargs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'label': 'POSITIVE', 'score': 2.557040214538574},\n", + " {'label': 'NEGATIVE', 'score': -2.294790267944336}]" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "text = \"this movie was really good!!\"\n", + "sentiment_pipe(text, **sent_kwargs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Generation settings\n", + "For the response generation we just use sampling and make sure top-k and nucleus sampling are turned off as well as a minimal length." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gen_kwargs = {\n", + " \"min_length\": -1,\n", + " \"top_k\": 0.0,\n", + " \"top_p\": 1.0,\n", + " \"do_sample\": True,\n", + " \"pad_token_id\": tokenizer.eos_token_id,\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Optimize model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training loop" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The training loop consists of the following main steps:\n", + "1. Get the query responses from the policy network (GPT-2)\n", + "2. Get sentiments for query/responses from BERT\n", + "3. Optimize policy with PPO using the (query, response, reward) triplet\n", + "\n", + "**Training time**\n", + "\n", + "This step takes **~2h** on a V100 GPU with the above specified settings." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_min_length = 4\n", + "output_max_length = 16\n", + "output_length_sampler = LengthSampler(output_min_length, output_max_length)\n", + "\n", + "\n", + "generation_kwargs = {\n", + " \"min_length\": -1,\n", + " \"top_k\": 0.0,\n", + " \"top_p\": 1.0,\n", + " \"do_sample\": True,\n", + " \"pad_token_id\": tokenizer.eos_token_id,\n", + "}\n", + "\n", + "\n", + "for epoch, batch in enumerate(tqdm(ppo_trainer.dataloader)):\n", + " query_tensors = batch[\"input_ids\"]\n", + "\n", + " #### Get response from gpt2\n", + " response_tensors = []\n", + " for query in query_tensors:\n", + " gen_len = output_length_sampler()\n", + " generation_kwargs[\"max_new_tokens\"] = gen_len\n", + " query_response = ppo_trainer.generate(query, **generation_kwargs).squeeze()\n", + " response_len = len(query_response) - len(query)\n", + " response_tensors.append(query_response[-response_len:])\n", + " batch[\"response\"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]\n", + "\n", + " #### Compute sentiment score\n", + " texts = [q + r for q, r in zip(batch[\"query\"], batch[\"response\"])]\n", + " pipe_outputs = sentiment_pipe(texts, **sent_kwargs)\n", + " positive_scores = [\n", + " item[\"score\"]\n", + " for output in pipe_outputs\n", + " for item in output\n", + " if item[\"label\"] == \"POSITIVE\"\n", + " ]\n", + " rewards = [torch.tensor(score) for score in positive_scores]\n", + "\n", + " #### Run PPO step\n", + " stats = ppo_trainer.step(query_tensors, response_tensors, rewards)\n", + " ppo_trainer.log_stats(stats, batch, rewards)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training progress\n", + "If you are tracking the training progress with Weights&Biases you should see a plot similar to the one below. Check out the interactive sample report on wandb.ai: [link](https://wandb.ai/huggingface/trl/runs/w9l3110g).\n", + "\n", + "
\n", + "\n", + "

Figure: Reward mean and distribution evolution during training.

\n", + "
\n", + "\n", + "One can observe how the model starts to generate more positive outputs after a few optimisation steps.\n", + "\n", + "> Note: Investigating the KL-divergence will probably show that at this point the model has not converged to the target KL-divergence, yet. To get there would require longer training or starting with a higher initial coefficient." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model inspection\n", + "Let's inspect some examples from the IMDB dataset. We can use `ref_model` to compare the tuned model `model` against the model before optimisation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
queryresponse (before)response (after)rewards (before)rewards (after)
0I rented Zero Day4 for my sister. To my surprise, the Wii caug.... It is a pleasure. It is a huge leap 68 years...1.7360682.423731
1The onlydistro of herspecial compliments is the0.1508520.190159
2I've read a fewnews reports about Mr. Mueller's activities b...novels and I never watch this. It has a reall...-1.4179622.831814
3This is the second British Rank film, and I wouldn't be surprised anymore if itthat I have enjoyed, achieving it in both the0.8358762.205628
4A classicclassic.<br /><br />And only this one will ha.... It's a movie with a fine cast. As the beginn...2.1130752.739168
5This has to be one of theworst with the differences being that for thebest thriller films I've seen in recent-2.7053392.730615
6Happy Go Lovely is a waste. Not only are extremelyof time, giving a-2.429504-2.934672
7Wow, I justcan't make fun of itfeek it! This show-2.201666-0.106085
8This movie makes several mistakes.Despite being a great comedic diversion it es...It's cool, wonderful - it held me into a very ...-1.2323802.707638
9Branagh and Fishburne, Drake is playedis a great show. Beautiful0.7768192.808996
10I might have given this movie arating of *11 when I heard that!), but it was...great performance. It was truly a great movie...0.2763802.743328
11Really, really badwith feel like there is no end to the. This movie is incredibly good, with the-2.639503-1.568827
12What another reviewer called lack ofjudgment, connecting into her own harsh obser...suspense. Rogers and Rooney rate this as exce...-1.0797072.696888
13This is simply onemore problem of Steveof the best choice-1.4454362.662699
14\"Perhaps we can arrange a meet-and-greet.<br /><br />Telegwith spent, classic music and dance, and come...0.2584791.876662
15Richard Willaims isnice enough; the little black guy plays quitebeautifully hands on in his own spin, and0.7965082.820259
\n", + "
" + ], + "text/plain": [ + " query \\\n", + "0 I rented Zero Day \n", + "1 The only \n", + "2 I've read a few \n", + "3 This is the second British Rank film \n", + "4 A classic \n", + "5 This has to be one of the \n", + "6 Happy Go Lovely is a waste \n", + "7 Wow, I just \n", + "8 This movie makes several mistakes. \n", + "9 Branagh and Fish \n", + "10 I might have given this movie a \n", + "11 Really, really bad \n", + "12 What another reviewer called lack of \n", + "13 This is simply one \n", + "14 \"Perhaps we can arrange a meet \n", + "15 Richard Willaims is \n", + "\n", + " response (before) \\\n", + "0 4 for my sister. To my surprise, the Wii caug... \n", + "1 distro of her \n", + "2 news reports about Mr. Mueller's activities b... \n", + "3 , and I wouldn't be surprised anymore if it \n", + "4 classic.

And only this one will ha... \n", + "5 worst with the differences being that for the \n", + "6 . Not only are extremely \n", + "7 can't make fun of it \n", + "8 Despite being a great comedic diversion it es... \n", + "9 burne, Drake is played \n", + "10 rating of *11 when I heard that!), but it was... \n", + "11 with feel like there is no end to the \n", + "12 judgment, connecting into her own harsh obser... \n", + "13 more problem of Steve \n", + "14 -and-greet.

Teleg \n", + "15 nice enough; the little black guy plays quite \n", + "\n", + " response (after) rewards (before) \\\n", + "0 . It is a pleasure. It is a huge leap 68 years... 1.736068 \n", + "1 special compliments is the 0.150852 \n", + "2 novels and I never watch this. It has a reall... -1.417962 \n", + "3 that I have enjoyed, achieving it in both the 0.835876 \n", + "4 . It's a movie with a fine cast. As the beginn... 2.113075 \n", + "5 best thriller films I've seen in recent -2.705339 \n", + "6 of time, giving a -2.429504 \n", + "7 feek it! This show -2.201666 \n", + "8 It's cool, wonderful - it held me into a very ... -1.232380 \n", + "9 is a great show. Beautiful 0.776819 \n", + "10 great performance. It was truly a great movie... 0.276380 \n", + "11 . This movie is incredibly good, with the -2.639503 \n", + "12 suspense. Rogers and Rooney rate this as exce... -1.079707 \n", + "13 of the best choice -1.445436 \n", + "14 with spent, classic music and dance, and come... 0.258479 \n", + "15 beautifully hands on in his own spin, and 0.796508 \n", + "\n", + " rewards (after) \n", + "0 2.423731 \n", + "1 0.190159 \n", + "2 2.831814 \n", + "3 2.205628 \n", + "4 2.739168 \n", + "5 2.730615 \n", + "6 -2.934672 \n", + "7 -0.106085 \n", + "8 2.707638 \n", + "9 2.808996 \n", + "10 2.743328 \n", + "11 -1.568827 \n", + "12 2.696888 \n", + "13 2.662699 \n", + "14 1.876662 \n", + "15 2.820259 " + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#### get a batch from the dataset\n", + "bs = 16\n", + "game_data = dict()\n", + "dataset.set_format(\"pandas\")\n", + "df_batch = dataset[:].sample(bs)\n", + "game_data[\"query\"] = df_batch[\"query\"].tolist()\n", + "query_tensors = df_batch[\"input_ids\"].tolist()\n", + "\n", + "response_tensors_ref, response_tensors = [], []\n", + "\n", + "#### get response from gpt2 and gpt2_ref\n", + "for i in range(bs):\n", + " query = torch.tensor(query_tensors[i]).to(device)\n", + "\n", + " gen_len = output_length_sampler()\n", + " query_response = ref_model.generate(\n", + " query.unsqueeze(0), max_new_tokens=gen_len, **gen_kwargs\n", + " ).squeeze()\n", + " response_len = len(query_response) - len(query)\n", + " response_tensors_ref.append(query_response[-response_len:])\n", + "\n", + " query_response = model.generate(\n", + " query.unsqueeze(0), max_new_tokens=gen_len, **gen_kwargs\n", + " ).squeeze()\n", + " response_len = len(query_response) - len(query)\n", + " response_tensors.append(query_response[-response_len:])\n", + "\n", + "#### decode responses\n", + "game_data[\"response (before)\"] = [\n", + " tokenizer.decode(response_tensors_ref[i]) for i in range(bs)\n", + "]\n", + "game_data[\"response (after)\"] = [\n", + " tokenizer.decode(response_tensors[i]) for i in range(bs)\n", + "]\n", + "\n", + "#### sentiment analysis of query/response pairs before/after\n", + "texts = [q + r for q, r in zip(game_data[\"query\"], game_data[\"response (before)\"])]\n", + "pipe_outputs = sentiment_pipe(texts, **sent_kwargs)\n", + "positive_scores = [\n", + " item[\"score\"]\n", + " for output in pipe_outputs\n", + " for item in output\n", + " if item[\"label\"] == \"POSITIVE\"\n", + "]\n", + "game_data[\"rewards (before)\"] = positive_scores\n", + "\n", + "texts = [q + r for q, r in zip(game_data[\"query\"], game_data[\"response (after)\"])]\n", + "pipe_outputs = sentiment_pipe(texts, **sent_kwargs)\n", + "positive_scores = [\n", + " item[\"score\"]\n", + " for output in pipe_outputs\n", + " for item in output\n", + " if item[\"label\"] == \"POSITIVE\"\n", + "]\n", + "game_data[\"rewards (after)\"] = positive_scores\n", + "\n", + "# store results in a dataframe\n", + "df_results = pd.DataFrame(game_data)\n", + "df_results" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Looking at the reward mean/median of the generated sequences we observe a significant difference." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mean:\n" + ] + }, + { + "data": { + "text/plain": [ + "rewards (before) -0.512965\n", + "rewards (after) 1.676750\n", + "dtype: float64" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "median:\n" + ] + }, + { + "data": { + "text/plain": [ + "rewards (before) -0.464427\n", + "rewards (after) 2.679794\n", + "dtype: float64" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "print(\"mean:\")\n", + "display(df_results[[\"rewards (before)\", \"rewards (after)\"]].mean())\n", + "print()\n", + "print(\"median:\")\n", + "display(df_results[[\"rewards (before)\", \"rewards (after)\"]].median())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Save model\n", + "Finally, we save the model and push it to the Hugging Face for later usage." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "('gpt2-imdb-pos-v2/tokenizer_config.json',\n", + " 'gpt2-imdb-pos-v2/special_tokens_map.json',\n", + " 'gpt2-imdb-pos-v2/vocab.json',\n", + " 'gpt2-imdb-pos-v2/merges.txt',\n", + " 'gpt2-imdb-pos-v2/added_tokens.json',\n", + " 'gpt2-imdb-pos-v2/tokenizer.json')" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.save_pretrained(\"gpt2-imdb-pos-v2\", push_to_hub=True)\n", + "tokenizer.save_pretrained(\"gpt2-imdb-pos-v2\", push_to_hub=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + }, + "vscode": { + "interpreter": { + "hash": "4c8ff454cd947027f86954d72bf940c689a97dcc494eb53cfe4813862c6065fe" + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/research_projects/README.md b/examples/research_projects/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1b1977e1877ca1d6351cd888b76793a2bad3206d --- /dev/null +++ b/examples/research_projects/README.md @@ -0,0 +1,7 @@ +# Research projects that use TRL + +Welcome to the research projects folder! Here you can find the scripts used for some research projects that used TRL and maintained by the developers and the community (LM de-toxification, Stack-Llama, etc.). Check out the READMEs in the subfolders for more information! + +- [De-detoxifying language models](https://github.com/huggingface/trl/tree/main/examples/research_projects/toxicity) +- [Stack-Llama](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama) +- [Stack-Llama-2](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama_2) \ No newline at end of file diff --git a/examples/research_projects/layer_skip/README.md b/examples/research_projects/layer_skip/README.md new file mode 100644 index 0000000000000000000000000000000000000000..b673a48c483d5341ceb4457c7e751d923ead0aaa --- /dev/null +++ b/examples/research_projects/layer_skip/README.md @@ -0,0 +1,15 @@ +# LayerSkip Training Recipe + +Implements the training recipe as described in the [LayerSkip paper](https://huggingface.co/papers/2404.16710). + +## Run training +``` +cd scripts +python layer_skip_sft.py +``` + +## Run benchmark +``` +cd scripts +python benchmark_layer_skip.py +``` \ No newline at end of file diff --git a/examples/research_projects/layer_skip/scripts/benchmark_layer_skip.py b/examples/research_projects/layer_skip/scripts/benchmark_layer_skip.py new file mode 100644 index 0000000000000000000000000000000000000000..9359dda33eb048515bb2f45d5a89b4a998af41ff --- /dev/null +++ b/examples/research_projects/layer_skip/scripts/benchmark_layer_skip.py @@ -0,0 +1,77 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import config +import torch +from torch.utils import benchmark +from transformers import AutoModelForCausalLM, AutoTokenizer + + +def generate_tokens(model, inputs): + outputs = model.generate( + **inputs, + do_sample=False, + max_new_tokens=64, + ) + return outputs + + +def generate_tokens_with_assistance(model, inputs, assistant_early_exit): + outputs = model.generate( + **inputs, + assistant_early_exit=assistant_early_exit, + do_sample=False, + max_new_tokens=64, + ) + return outputs + + +if __name__ == "__main__": + ckpt = config.hub_model_id + + model = AutoModelForCausalLM.from_pretrained(ckpt, device_map="auto", torch_dtype=torch.bfloat16) + tokenizer = AutoTokenizer.from_pretrained(ckpt) + + prompt = "### Instruction: What are my alarms for the rest of the day?\n ### Response: " + + results = [] + label = "Generation Times" + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + + results.append( + benchmark.Timer( + stmt="generate_tokens(model, inputs)", + setup="from __main__ import generate_tokens", + globals={"model": model, "inputs": inputs}, + num_threads=torch.get_num_threads(), + label=label, + sub_label="no layer skip", + description="generation", + ).blocked_autorange() + ) + + for i in range(1, model.config.num_hidden_layers): + results.append( + benchmark.Timer( + stmt="generate_tokens_with_assistance(model, inputs, assistant_early_exit)", + setup="from __main__ import generate_assistant_tokens", + globals={"model": model, "assistant_early_exit": i, "inputs": inputs}, + num_threads=torch.get_num_threads(), + label=label, + sub_label=f"layer skip {i}", + description="generation", + ).blocked_autorange() + ) + + benchmark.Compare(results).print() diff --git a/examples/research_projects/layer_skip/scripts/config.py b/examples/research_projects/layer_skip/scripts/config.py new file mode 100644 index 0000000000000000000000000000000000000000..2c9870513df0f93f44db3359d15f18a60f8d6547 --- /dev/null +++ b/examples/research_projects/layer_skip/scripts/config.py @@ -0,0 +1,28 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from huggingface_hub import whoami + + +model_name = "unsloth/Llama-3.2-3B" +tokenizer_name = "unsloth/Llama-3.2-3B" +dataset_name = "WillHeld/top_v2" + +output_root_dir = "./checkpoints/" +hub_model_id = f"{whoami()['name']}/layerskip-{model_name.split('/')[1]}-{dataset_name.split('/')[1]}" +output_dir = f"{output_root_dir}/{hub_model_id}" + +per_device_train_batch_size = 8 +gradient_accumulation_steps = 1 +learning_rate = 2e-5 diff --git a/examples/research_projects/layer_skip/scripts/custom_trainer.py b/examples/research_projects/layer_skip/scripts/custom_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..8e52b6e23e6fcfa4f850d8bc74c75c9fc46424f4 --- /dev/null +++ b/examples/research_projects/layer_skip/scripts/custom_trainer.py @@ -0,0 +1,48 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from trl import SFTTrainer + + +class LayerSkipSFTTrainer(SFTTrainer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.early_exit_layer = 0 # initialize with 0 + self.always_last_layer = True + self.early_exit_loss_scale = 1.0 + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + self.early_exit_layer = ( + self.early_exit_layer % (model.config.num_hidden_layers - 1) + ) + 1 # rotates between [1, num_hidden_layers-1] + bs, seqlen = inputs.input_ids.shape + + labels = inputs.pop("labels") + outputs = model(**inputs, output_hidden_states=True) + + hidden_state = outputs["hidden_states"][self.early_exit_layer].to(model.dtype) + if self.early_exit_layer != model.config.num_hidden_layers: + hidden_state = model.model.norm(hidden_state) + logits = model.lm_head(hidden_state) + loss_early = model.loss_function(logits=logits, labels=labels, vocab_size=model.vocab_size) + + if self.always_last_layer: + loss_last = model.loss_function(logits=outputs["logits"], labels=labels, vocab_size=model.vocab_size) + loss = self.early_exit_loss_scale * loss_early.to(loss_last.device) + 1.0 * loss_last + # normalize loss scales + loss = loss / (1.0 + self.early_exit_loss_scale) + else: + loss = loss_early + + return loss diff --git a/examples/research_projects/layer_skip/scripts/layer_skip_sft.py b/examples/research_projects/layer_skip/scripts/layer_skip_sft.py new file mode 100644 index 0000000000000000000000000000000000000000..bddb36b620ad34b373e70ffa757475dbd1d3670d --- /dev/null +++ b/examples/research_projects/layer_skip/scripts/layer_skip_sft.py @@ -0,0 +1,91 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import config +import torch +from custom_trainer import LayerSkipSFTTrainer +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from trl import DataCollatorForCompletionOnlyLM, SFTConfig + + +def formatting_prompts_func(example): + text = f"### Instruction: {example['utterance']}\n ### Response: {example['semantic_parse']}" + + # Inject eos_token as a string before tokenization, because they are not always added + # See: https://github.com/huggingface/transformers/issues/22794 and + # https://github.com/huggingface/trl/issues/1623 + if tokenizer.eos_token: # usually something like "" for GPT2 or "<|endoftext|>" + text += f"{tokenizer.eos_token}" + + return text + + +if __name__ == "__main__": + # load the dataset + print("[INFO] loading the dataset...") + train_dataset = load_dataset(config.dataset_name, split="train") + + print(f"output_root_dir: {config.output_root_dir}") + print(f"hub_model_id: {config.hub_model_id}") + + # load the model and tokenizer + print("[INFO] loading the model and tokenizer...") + model = AutoModelForCausalLM.from_pretrained(config.model_name, device_map="auto", torch_dtype=torch.bfloat16) + tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, add_eos_token=True) + + # adding pad and eos tokens if not provided in the tokenizer + if tokenizer.pad_token is None: + # Add '[PAD]' token if it doesn't exist + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + model.resize_token_embeddings(len(tokenizer)) + model.config.pad_token_id = tokenizer.pad_token_id + + if tokenizer.eos_token is None or tokenizer.eos_token == tokenizer.bos_token: + # Add '[EOS]' token if it doesn't exist + tokenizer.add_special_tokens({"eos_token": "[EOS]"}) + model.resize_token_embeddings(len(tokenizer)) + model.config.eos_token_id = tokenizer.eos_token_id + + response_template = " ### Response:" + collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer) + + args = SFTConfig( + do_train=True, + bf16=True, + max_seq_length=None, + per_device_train_batch_size=config.per_device_train_batch_size, + gradient_accumulation_steps=config.gradient_accumulation_steps, + learning_rate=config.learning_rate, + packing=False, + num_train_epochs=1.0, + report_to="none", + push_to_hub=True, + hub_model_id=config.hub_model_id, + output_dir=config.output_dir, + logging_steps=500, + save_steps=1000, + save_total_limit=2, + ) + + trainer = LayerSkipSFTTrainer( + model, + train_dataset=train_dataset, + args=args, + formatting_func=formatting_prompts_func, + data_collator=collator, + ) + + trainer.train() diff --git a/examples/research_projects/stack_llama/scripts/README.md b/examples/research_projects/stack_llama/scripts/README.md new file mode 100644 index 0000000000000000000000000000000000000000..da9f067f20cc73ae14889ec6d40110a9c79598e3 --- /dev/null +++ b/examples/research_projects/stack_llama/scripts/README.md @@ -0,0 +1,18 @@ +# RLHF pipeline for the creation of StackLLaMa: a Stack exchange llama-7b model. +There were three main steps to the training process: +1. Supervised fine-tuning of the base llama-7b model to create llama-7b-se: + - `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/supervised_finetuning.py --model_path= --streaming --learning_rate 1e-5 --max_steps 5000 --output_dir ./llama-se` +2. Reward modeling using dialog pairs from the SE dataset using the llama-7b-se to create llama-7b-se-rm: + - `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/reward_modeling.py --model_name=` +3. RL fine-tuning of llama-7b-se with the llama-7b-se-rm reward model: + - `accelerate launch --multi_gpu --num_machines 1 --num_processes 8 examples/research_projects/stack_llama/scripts/rl_training.py --log_with=wandb --model_name= --reward_model_name= --adafactor=False --tokenizer_name= --save_freq=100 --output_max_length=128 --batch_size=8 --gradient_accumulation_steps=8 --batched_gen=True --ppo_epochs=4 --seed=0 --learning_rate=1.4e-5 --early_stopping=True --output_dir=llama-se-rl-finetune-128-8-8-1.4e-5_adam` + + +LoRA layers were using at all stages to reduce memory requirements. +At each stage the peft adapter layers were merged with the base model, using: +```shell +python examples/research_projects/stack_llama/scripts/merge_peft_adapter.py --adapter_model_name=XXX --base_model_name=YYY --output_name=ZZZ +``` +Note that this script requires `peft>=0.3.0`. + +For access to the base llama-7b model, please see Meta's [release](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) and [request form](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform). diff --git a/examples/research_projects/stack_llama/scripts/merge_peft_adapter.py b/examples/research_projects/stack_llama/scripts/merge_peft_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..9c7974ae0fab94f0da9346d418398a12ac934662 --- /dev/null +++ b/examples/research_projects/stack_llama/scripts/merge_peft_adapter.py @@ -0,0 +1,62 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional + +import torch +from peft import PeftConfig, PeftModel +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser + + +@dataclass +class ScriptArguments: + """ + The input names representing the Adapter and Base model fine-tuned with PEFT, and the output name representing the + merged model. + """ + + adapter_model_name: Optional[str] = field(default=None, metadata={"help": "the adapter name"}) + base_model_name: Optional[str] = field(default=None, metadata={"help": "the base model name"}) + output_name: Optional[str] = field(default=None, metadata={"help": "the merged model name"}) + + +parser = HfArgumentParser(ScriptArguments) +script_args = parser.parse_args_into_dataclasses()[0] +assert script_args.adapter_model_name is not None, "please provide the name of the Adapter you would like to merge" +assert script_args.base_model_name is not None, "please provide the name of the Base model" +assert script_args.output_name is not None, "please provide the output name of the merged model" + +peft_config = PeftConfig.from_pretrained(script_args.adapter_model_name) +if peft_config.task_type == "SEQ_CLS": + # The sequence classification task is used for the reward model in PPO + model = AutoModelForSequenceClassification.from_pretrained( + script_args.base_model_name, num_labels=1, torch_dtype=torch.bfloat16 + ) +else: + model = AutoModelForCausalLM.from_pretrained( + script_args.base_model_name, return_dict=True, torch_dtype=torch.bfloat16 + ) + +tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name) + +# Load the PEFT model +model = PeftModel.from_pretrained(model, script_args.adapter_model_name) +model.eval() + +model = model.merge_and_unload() + +model.save_pretrained(f"{script_args.output_name}") +tokenizer.save_pretrained(f"{script_args.output_name}") +model.push_to_hub(f"{script_args.output_name}", use_temp_dir=False) diff --git a/examples/research_projects/stack_llama/scripts/reward_modeling.py b/examples/research_projects/stack_llama/scripts/reward_modeling.py new file mode 100644 index 0000000000000000000000000000000000000000..e45781b66d9d95d0f7ffa6900a6e8ae175ea11e1 --- /dev/null +++ b/examples/research_projects/stack_llama/scripts/reward_modeling.py @@ -0,0 +1,324 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any, Optional, Union + +import evaluate +import numpy as np +import torch +import torch.nn as nn +from datasets import load_dataset +from peft import LoraConfig, TaskType, get_peft_model +from transformers import ( + AutoModelForSequenceClassification, + AutoTokenizer, + HfArgumentParser, + PreTrainedTokenizerBase, + Trainer, + TrainerCallback, + TrainingArguments, + set_seed, +) +from transformers.utils import PaddingStrategy + + +# Define and parse arguments. +@dataclass +class ScriptArguments: + """ + These arguments vary depending on how many GPUs you have, what their capacity and features are, and what size model you want to train. + """ + + local_rank: Optional[int] = field(default=-1, metadata={"help": "Used for multi-gpu"}) + resume_from_checkpoint: Optional[bool] = field( + default=False, + metadata={"help": "If you want to resume training where it left off."}, + ) + deepspeed: Optional[str] = field( + default=None, + metadata={ + "help": "Path to deepspeed config if using deepspeed. You may need this if the model that you want to train doesn't fit on a single GPU." + }, + ) + per_device_train_batch_size: Optional[int] = field(default=4) + per_device_eval_batch_size: Optional[int] = field(default=1) + gradient_accumulation_steps: Optional[int] = field(default=1) + learning_rate: Optional[float] = field(default=2e-5) + weight_decay: Optional[float] = field(default=0.001) + model_name: Optional[str] = field( + default="gpt2", + metadata={ + "help": "The model that you want to train from the Hugging Face hub. E.g. gpt2, gpt2-xl, bert, etc." + }, + ) + tokenizer_name: Optional[str] = field( + default=None, + metadata={ + "help": "The tokenizer for your model, if left empty will use the default for your model", + }, + ) + bf16: Optional[bool] = field( + default=True, + metadata={ + "help": "This essentially cuts the training time in half if you want to sacrifice a little precision and have a supported GPU." + }, + ) + num_train_epochs: Optional[int] = field( + default=1, + metadata={"help": "The number of training epochs for the reward model."}, + ) + train_subset: Optional[int] = field( + default=100000, + metadata={"help": "The size of the subset of the training data to use"}, + ) + eval_subset: Optional[int] = field( + default=50000, + metadata={"help": "The size of the subset of the eval data to use"}, + ) + gradient_checkpointing: Optional[bool] = field( + default=False, + metadata={"help": "Enables gradient checkpointing."}, + ) + optim: Optional[str] = field( + default="adamw_hf", + metadata={"help": "The optimizer to use."}, + ) + lr_scheduler_type: Optional[str] = field( + default="linear", + metadata={"help": "The lr scheduler"}, + ) + max_length: Optional[int] = field(default=512) + eval_first_step: Optional[bool] = field( + default=False, + metadata={"help": "Whether to run eval after the first step"}, + ) + seed: Optional[int] = field( + default=0, metadata={"help": "Random seed that will be set at the beginning of training."} + ) + + +parser = HfArgumentParser(ScriptArguments) +script_args = parser.parse_args_into_dataclasses()[0] +set_seed(script_args.seed) +# Load the human stack-exchange-paired dataset for tuning the reward model. +train_dataset = load_dataset( + "lvwerra/stack-exchange-paired", data_dir="data/reward", split="train", verification_mode="no_checks" +) +if script_args.train_subset > 0: + train_dataset = train_dataset.select(range(script_args.train_subset)) +eval_dataset = load_dataset( + "lvwerra/stack-exchange-paired", data_dir="data/evaluation", split="train", verification_mode="no_checks" +) +if script_args.eval_subset > 0: + eval_dataset = eval_dataset.select(range(script_args.eval_subset)) +# Define the training args. Needs to be done before the model is loaded if you are using deepspeed. +model_name_split = script_args.model_name.split("/")[-1] +output_name = ( + f"{model_name_split}_peft_stack-exchange-paired_rmts__{script_args.train_subset}_{script_args.learning_rate}" +) + +training_args = TrainingArguments( + output_dir=output_name, + learning_rate=script_args.learning_rate, + per_device_train_batch_size=script_args.per_device_train_batch_size, + per_device_eval_batch_size=script_args.per_device_eval_batch_size, + num_train_epochs=script_args.num_train_epochs, + weight_decay=script_args.weight_decay, + eval_strategy="steps", + eval_steps=500, + save_strategy="steps", + save_steps=500, + gradient_accumulation_steps=script_args.gradient_accumulation_steps, + gradient_checkpointing=script_args.gradient_checkpointing, + deepspeed=script_args.deepspeed, + local_rank=script_args.local_rank, + remove_unused_columns=False, + label_names=[], + bf16=script_args.bf16, + logging_strategy="steps", + logging_steps=10, + optim=script_args.optim, + lr_scheduler_type=script_args.lr_scheduler_type, + seed=script_args.seed, +) + + +# Load the value-head model and tokenizer. +tokenizer_name = script_args.tokenizer_name if script_args.tokenizer_name is not None else script_args.model_name +tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_auth_token=True) +tokenizer.pad_token = tokenizer.eos_token + + +peft_config = LoraConfig( + task_type=TaskType.SEQ_CLS, + inference_mode=False, + r=8, + lora_alpha=32, + lora_dropout=0.1, +) + +model = AutoModelForSequenceClassification.from_pretrained( + script_args.model_name, num_labels=1, torch_dtype=torch.bfloat16 +) +model = get_peft_model(model, peft_config) +model.print_trainable_parameters() + +# Need to do this for gpt2, because it doesn't have an official pad token. +tokenizer.pad_token = tokenizer.eos_token +model.config.pad_token_id = tokenizer.eos_token_id +model.config.use_cache = not script_args.gradient_checkpointing +num_proc = 24 # Can adjust to be higher if you have more processors. +original_columns = train_dataset.column_names + + +# Turn the dataset into pairs of post + summaries, where text_j is the preferred question + answer and text_k is the other. +# Then tokenize the dataset. +def preprocess_function(examples): + new_examples = { + "input_ids_j": [], + "attention_mask_j": [], + "input_ids_k": [], + "attention_mask_k": [], + } + for question, response_j, response_k in zip(examples["question"], examples["response_j"], examples["response_k"]): + tokenized_j = tokenizer("Question: " + question + "\n\nAnswer: " + response_j, truncation=True) + tokenized_k = tokenizer("Question: " + question + "\n\nAnswer: " + response_k, truncation=True) + + new_examples["input_ids_j"].append(tokenized_j["input_ids"]) + new_examples["attention_mask_j"].append(tokenized_j["attention_mask"]) + new_examples["input_ids_k"].append(tokenized_k["input_ids"]) + new_examples["attention_mask_k"].append(tokenized_k["attention_mask"]) + + return new_examples + + +# preprocess the dataset and filter out QAs that are longer than script_args.max_length +train_dataset = train_dataset.map( + preprocess_function, + batched=True, + num_proc=num_proc, + remove_columns=original_columns, +) +train_dataset = train_dataset.filter( + lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length, + num_proc=num_proc, +) + +eval_dataset = eval_dataset.map( + preprocess_function, + batched=True, + num_proc=num_proc, + remove_columns=original_columns, +) +eval_dataset = eval_dataset.filter( + lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length, + num_proc=num_proc, +) + + +# We need to define a special data collator that batches the data in our j vs k format. +@dataclass +class RewardDataCollatorWithPadding: + tokenizer: PreTrainedTokenizerBase + padding: Union[bool, str, PaddingStrategy] = True + pad_to_multiple_of: Optional[int] = None + return_tensors: str = "pt" + + def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: + features_j = [] + features_k = [] + for feature in features: + features_j.append( + { + "input_ids": feature["input_ids_j"], + "attention_mask": feature["attention_mask_j"], + } + ) + features_k.append( + { + "input_ids": feature["input_ids_k"], + "attention_mask": feature["attention_mask_k"], + } + ) + batch_j = self.tokenizer.pad( + features_j, + padding=self.padding, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors=self.return_tensors, + ) + batch_k = self.tokenizer.pad( + features_k, + padding=self.padding, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors=self.return_tensors, + ) + batch = { + "input_ids_j": batch_j["input_ids"], + "attention_mask_j": batch_j["attention_mask"], + "input_ids_k": batch_k["input_ids"], + "attention_mask_k": batch_k["attention_mask"], + "return_loss": True, + } + return batch + + +# Define the metric that we'll use for validation. +accuracy = evaluate.load("accuracy") + + +def compute_metrics(eval_pred): + predictions, _ = eval_pred + # Here, predictions is rewards_j and rewards_k. + # We want to see how much of the time rewards_j > rewards_k. + predictions = np.argmax(predictions, axis=0) + labels = np.zeros(predictions.shape) + return accuracy.compute(predictions=predictions, references=labels) + + +class RewardTrainer(Trainer): + # Define how to compute the reward loss. We use the InstructGPT pairwise logloss: https://huggingface.co/papers/2203.02155 + def compute_loss(self, model, inputs, return_outputs=False): + rewards_j = model(input_ids=inputs["input_ids_j"], attention_mask=inputs["attention_mask_j"])[0] + rewards_k = model(input_ids=inputs["input_ids_k"], attention_mask=inputs["attention_mask_k"])[0] + loss = -nn.functional.logsigmoid(rewards_j - rewards_k).mean() + if return_outputs: + return loss, {"rewards_j": rewards_j, "rewards_k": rewards_k} + return loss + + +# Train the model, woohoo. +trainer = RewardTrainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + compute_metrics=compute_metrics, + data_collator=RewardDataCollatorWithPadding(tokenizer=tokenizer), +) + + +if script_args.eval_first_step: + + class EvaluateFirstStepCallback(TrainerCallback): + def on_step_end(self, args, state, control, **kwargs): + if state.global_step == 1: + control.should_evaluate = True + + trainer.add_callback(EvaluateFirstStepCallback()) + +trainer.train(script_args.resume_from_checkpoint) + +print("Saving last checkpoint of the model") +model.save_pretrained(output_name + "_peft_last_checkpoint") diff --git a/examples/research_projects/stack_llama/scripts/rl_training.py b/examples/research_projects/stack_llama/scripts/rl_training.py new file mode 100644 index 0000000000000000000000000000000000000000..c67adbe6d658aa6d40e3c67b0288d052b9d56247 --- /dev/null +++ b/examples/research_projects/stack_llama/scripts/rl_training.py @@ -0,0 +1,268 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional + +import torch +from accelerate import Accelerator +from datasets import load_dataset +from peft import LoraConfig +from tqdm import tqdm +from transformers import Adafactor, AutoTokenizer, HfArgumentParser, pipeline, set_seed + +from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer +from trl.core import LengthSampler + + +tqdm.pandas() + + +@dataclass +class ScriptArguments: + """ + The name of the Casual LM model we wish to fine-tune with PPO + """ + + # NOTE: gpt2 models use Conv1D instead of Linear layers which are not yet supported in 8 bit mode + # models like gpt-neo* models are more suitable. + model_name: Optional[str] = field(default="", metadata={"help": "the model name"}) + tokenizer_name: Optional[str] = field(default="", metadata={"help": "the tokenizer name"}) + reward_model_name: Optional[str] = field(default="", metadata={"help": "the reward model name"}) + log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"}) + learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"}) + output_max_length: Optional[int] = field(default=128, metadata={"help": "maximum length for generation"}) + mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"}) + batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"}) + ppo_epochs: Optional[int] = field(default=4, metadata={"help": "the number of ppo epochs"}) + gradient_accumulation_steps: Optional[int] = field( + default=4, metadata={"help": "the number of gradient accumulation steps"} + ) + adafactor: Optional[bool] = field(default=False, metadata={"help": "whether to use the adafactor optimizer"}) + early_stopping: Optional[bool] = field(default=False, metadata={"help": "whether to early stop"}) + target_kl: Optional[float] = field(default=0.1, metadata={"help": "kl target for early stopping"}) + reward_baseline: Optional[float] = field( + default=0.0, + metadata={"help": "a baseline value that is subtracted from the reward"}, + ) + batched_gen: Optional[bool] = field(default=False, metadata={"help": "whether to use the batched text gen"}) + save_freq: Optional[int] = field(default=None, metadata={"help": "n steps to save the model"}) + output_dir: Optional[str] = field(default="runs/", metadata={"help": "n steps to save the model"}) + seed: Optional[int] = field(default=0, metadata={"help": "the seed"}) + steps: Optional[int] = field(default=20000, metadata={"help": "number of epochs"}) + init_kl_coef: Optional[float] = field( + default=0.2, + metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"}, + ) + + adap_kl_ctrl: Optional[bool] = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"}) + load_in_8bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 8bit"}) + + +parser = HfArgumentParser(ScriptArguments) +script_args: ScriptArguments = parser.parse_args_into_dataclasses()[0] +reward_model_name = script_args.reward_model_name +dataset_name = "lvwerra/stack-exchange-paired" +config = PPOConfig( + steps=script_args.steps, + model_name=script_args.model_name, + learning_rate=script_args.learning_rate, + log_with=script_args.log_with, + batch_size=script_args.batch_size, + mini_batch_size=script_args.mini_batch_size, + gradient_accumulation_steps=script_args.gradient_accumulation_steps, + optimize_device_cache=True, + early_stopping=script_args.early_stopping, + target_kl=script_args.target_kl, + ppo_epochs=script_args.ppo_epochs, + seed=script_args.seed, + init_kl_coef=script_args.init_kl_coef, + adap_kl_ctrl=script_args.adap_kl_ctrl, +) + +train_dataset = load_dataset( + "lvwerra/stack-exchange-paired", data_dir="data/rl", split="train", verification_mode="no_checks" +) +train_dataset = train_dataset.select(range(100000)) +original_columns = train_dataset.column_names + +# We then define the arguments to pass to the sentiment analysis pipeline. +# We set `return_all_scores` to True to get the sentiment score for each token. +sent_kwargs = { + "return_all_scores": True, + "function_to_apply": "none", + "batch_size": 16, + "truncation": True, +} + +tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer_name) +# GPT-2 tokenizer has a pad token, but it is not eos_token by default. We need to set it to eos_token. +# only for this model. + +if getattr(tokenizer, "pad_token", None) is None: + tokenizer.pad_token = tokenizer.eos_token + + +# Below is an example function to build the dataset. In our case, we use the IMDB dataset +# from the `datasets` library. One should customize this function to train the model on +# its own dataset. +def build_dataset( + tokenizer, + dataset_name="lvwerra/stack-exchange-paired", +): + """ + Build dataset for training. This builds the dataset from `load_dataset`, one should + customize this function to train the model on its own dataset. + + Args: + dataset_name (`str`): + The name of the dataset to be loaded. + + Returns: + dataloader (`torch.utils.data.DataLoader`): + The dataloader for the dataset. + """ + + num_proc = 24 + + def preprocess_function(examples): + new_examples = { + "query": [], + "input_ids": [], + } + for question in examples["question"]: + query = "Question: " + question + "\n\nAnswer: " + tokenized_question = tokenizer(query, truncation=True) + new_examples["query"].append(query) + new_examples["input_ids"].append(tokenized_question["input_ids"]) + + return new_examples + + ds = train_dataset.map( + preprocess_function, + batched=True, + num_proc=num_proc, + remove_columns=original_columns, + ) + ds = ds.filter(lambda x: len(x["input_ids"]) < 512, batched=False, num_proc=num_proc) + + ds.set_format(type="torch") + return ds + + +# We retrieve the dataloader by calling the `build_dataset` function. +dataset = build_dataset(tokenizer) + + +def collator(data): + return {key: [d[key] for d in data] for key in data[0]} + + +# set seed before initializing value head for deterministic eval +set_seed(config.seed) + +# Now let's build the model, the reference model, and the tokenizer. +current_device = Accelerator().local_process_index + +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", +) +model = AutoModelForCausalLMWithValueHead.from_pretrained( + config.model_name, + load_in_8bit=script_args.load_in_8bit, + device_map={"": current_device}, + peft_config=lora_config, +) + +optimizer = None +if script_args.adafactor: + optimizer = Adafactor( + filter(lambda p: p.requires_grad, model.parameters()), + scale_parameter=False, + relative_step=False, + warmup_init=False, + lr=config.learning_rate, + ) +# We then build the PPOTrainer, passing the model, the reference model, the tokenizer +ppo_trainer = PPOTrainer( + config, + model, + ref_model=None, + tokenizer=tokenizer, + dataset=dataset, + data_collator=collator, + optimizer=optimizer, +) + +# We then build the sentiment analysis pipeline using our reward model, passing the +# model name and the sentiment analysis pipeline arguments. Let's also make sure to +# set the device to the same device as the PPOTrainer. +device = ppo_trainer.accelerator.device +if ppo_trainer.accelerator.num_processes == 1: + device = 0 if torch.cuda.is_available() else "cpu" # to avoid a ` pipeline` bug +sentiment_pipe = pipeline( + "sentiment-analysis", + model=reward_model_name, + device_map={"": current_device}, + model_kwargs={"load_in_8bit": script_args.load_in_8bit}, + tokenizer=tokenizer, + return_token_type_ids=False, +) + +if sentiment_pipe.model.config.pad_token_id is None: + sentiment_pipe.model.config.pad_token_id = sentiment_pipe.model.config.eos_token_id +# We then define the arguments to pass to the `generate` function. These arguments +# are passed to the `generate` function of the PPOTrainer, which is a wrapper around +# the `generate` function of the trained model. +generation_kwargs = { + # "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.pad_token_id, + "eos_token_id": 100_000, +} +output_min_length = 32 +output_max_length = script_args.output_max_length +output_length_sampler = LengthSampler(output_min_length, output_max_length) + +for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)): + if epoch >= config.total_ppo_epochs: + break + + question_tensors = batch["input_ids"] + + response_tensors = ppo_trainer.generate( + question_tensors, + return_prompt=False, + length_sampler=output_length_sampler, + **generation_kwargs, + ) + batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True) + + # Compute reward score (using the sentiment analysis pipeline) + texts = [q + r for q, r in zip(batch["query"], batch["response"])] + pipe_outputs = sentiment_pipe(texts, **sent_kwargs) + rewards = [torch.tensor(output[0]["score"] - script_args.reward_baseline) for output in pipe_outputs] + + # Run PPO step + stats = ppo_trainer.step(question_tensors, response_tensors, rewards) + ppo_trainer.log_stats(stats, batch, rewards) + + if script_args.save_freq and epoch and epoch % script_args.save_freq == 0: + ppo_trainer.save_pretrained(script_args.output_dir + f"step_{epoch}") diff --git a/examples/research_projects/stack_llama/scripts/supervised_finetuning.py b/examples/research_projects/stack_llama/scripts/supervised_finetuning.py new file mode 100644 index 0000000000000000000000000000000000000000..85714ce2c2ab2fbcff22c002d0997fd595b56979 --- /dev/null +++ b/examples/research_projects/stack_llama/scripts/supervised_finetuning.py @@ -0,0 +1,222 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +from accelerate import Accelerator +from datasets import load_dataset +from peft import LoraConfig +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, logging, set_seed + +from trl import SFTTrainer +from trl.trainer import ConstantLengthDataset + + +""" +Fine-Tune Llama-7b on SE paired dataset +""" + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, default="") + parser.add_argument("--dataset_name", type=str, default="lvwerra/stack-exchange-paired") + parser.add_argument("--subset", type=str, default="data/finetune") + parser.add_argument("--split", type=str, default="train") + parser.add_argument("--size_valid_set", type=int, default=4000) + parser.add_argument("--streaming", action="store_true") + parser.add_argument("--shuffle_buffer", type=int, default=5000) + + parser.add_argument("--seq_length", type=int, default=1024) + parser.add_argument("--max_steps", type=int, default=10000) + parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--eos_token_id", type=int, default=49152) + + parser.add_argument("--learning_rate", type=float, default=1e-4) + parser.add_argument("--lr_scheduler_type", type=str, default="cosine") + parser.add_argument("--num_warmup_steps", type=int, default=100) + parser.add_argument("--weight_decay", type=float, default=0.05) + + parser.add_argument("--local_rank", type=int, default=0) + parser.add_argument("--fp16", action="store_true", default=False) + parser.add_argument("--bf16", action="store_true", default=False) + parser.add_argument("--gradient_checkpointing", action="store_true", default=False) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--num_workers", type=int, default=None) + parser.add_argument("--output_dir", type=str, default="./checkpoints") + parser.add_argument("--log_freq", default=1, type=int) + parser.add_argument("--eval_freq", default=1000, type=int) + parser.add_argument("--save_freq", default=1000, type=int) + + return parser.parse_args() + + +def chars_token_ratio(dataset, tokenizer, nb_examples=400): + """ + Estimate the average number of characters per token in the dataset. + """ + total_characters, total_tokens = 0, 0 + for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples): + text = prepare_sample_text(example) + total_characters += len(text) + if tokenizer.is_fast: + total_tokens += len(tokenizer(text).tokens()) + else: + total_tokens += len(tokenizer.tokenize(text)) + + return total_characters / total_tokens + + +def print_trainable_parameters(model): + """ + Prints the number of trainable parameters in the model. + """ + trainable_params = 0 + all_param = 0 + for _, param in model.named_parameters(): + all_param += param.numel() + if param.requires_grad: + trainable_params += param.numel() + print( + f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}" + ) + + +def prepare_sample_text(example): + """Prepare the text from a sample of the dataset.""" + text = f"Question: {example['question']}\n\nAnswer: {example['response_j']}" + return text + + +def create_datasets(tokenizer, args): + dataset = load_dataset( + args.dataset_name, + data_dir=args.subset, + split=args.split, + use_auth_token=True, + num_proc=args.num_workers if not args.streaming else None, + streaming=args.streaming, + ) + if args.streaming: + print("Loading the dataset in streaming mode") + valid_data = dataset.take(args.size_valid_set) + train_data = dataset.skip(args.size_valid_set) + train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed) + else: + dataset = dataset.train_test_split(test_size=0.005, seed=args.seed) + train_data = dataset["train"] + valid_data = dataset["test"] + print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}") + + chars_per_token = chars_token_ratio(train_data, tokenizer) + print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}") + + train_dataset = ConstantLengthDataset( + tokenizer, + train_data, + formatting_func=prepare_sample_text, + infinite=True, + seq_length=args.seq_length, + chars_per_token=chars_per_token, + ) + valid_dataset = ConstantLengthDataset( + tokenizer, + valid_data, + formatting_func=prepare_sample_text, + infinite=False, + seq_length=args.seq_length, + chars_per_token=chars_per_token, + ) + return train_dataset, valid_dataset + + +def run_training(args, train_data, val_data): + print("Loading the model") + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + train_data.start_iteration = 0 + + print("Starting main loop") + + training_args = TrainingArguments( + output_dir=args.output_dir, + dataloader_drop_last=True, + eval_strategy="steps", + max_steps=args.max_steps, + eval_steps=args.eval_freq, + save_steps=args.save_freq, + logging_steps=args.log_freq, + per_device_train_batch_size=args.batch_size, + per_device_eval_batch_size=args.batch_size, + learning_rate=args.learning_rate, + lr_scheduler_type=args.lr_scheduler_type, + warmup_steps=args.num_warmup_steps, + gradient_accumulation_steps=args.gradient_accumulation_steps, + gradient_checkpointing=args.gradient_checkpointing, + fp16=args.fp16, + bf16=args.bf16, + weight_decay=args.weight_decay, + run_name="llama-7b-finetuned", + report_to="wandb", + ddp_find_unused_parameters=False, + ) + + model = AutoModelForCausalLM.from_pretrained( + args.model_path, load_in_8bit=True, device_map={"": Accelerator().process_index} + ) + + trainer = SFTTrainer( + model=model, + args=training_args, + train_dataset=train_data, + eval_dataset=val_data, + peft_config=lora_config, + packing=True, + ) + + print_trainable_parameters(trainer.model) + + print("Training...") + trainer.train() + + print("Saving last checkpoint of the model") + trainer.model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint/")) + + +def main(args): + tokenizer = AutoTokenizer.from_pretrained(args.model_path) + train_dataset, eval_dataset = create_datasets(tokenizer, args) + run_training(args, train_dataset, eval_dataset) + + +if __name__ == "__main__": + args = get_args() + assert args.model_path != "", "Please provide the llama model path" + + set_seed(args.seed) + os.makedirs(args.output_dir, exist_ok=True) + + logging.set_verbosity_error() + + main(args) diff --git a/examples/research_projects/stack_llama_2/scripts/README.md b/examples/research_projects/stack_llama_2/scripts/README.md new file mode 100644 index 0000000000000000000000000000000000000000..727a631d8d120f25f4605d93e97539443fd5da8d --- /dev/null +++ b/examples/research_projects/stack_llama_2/scripts/README.md @@ -0,0 +1,76 @@ +# DPO pipeline for the creation of StackLlaMa 2: a Stack exchange llama-v2-7b model + +## Prerequisites + +Install all the dependencies in the `requirements.txt`: + +``` +$ pip install -U -r requirements.txt +``` + +Since we will use `accelerate` for training, make sure to run: +``` +$ accelerate config +``` + +## Training + +There were two main steps to the DPO training process: +1. Supervised fine-tuning of the base llama-v2-7b model to create llama-v2-7b-se: + + ``` + accelerate launch examples/research_projects/stack_llama_2/scripts/sft_llama2.py \ + --output_dir="./sft" \ + --max_steps=500 \ + --logging_steps=10 \ + --save_steps=10 \ + --per_device_train_batch_size=4 \ + --per_device_eval_batch_size=1 \ + --gradient_accumulation_steps=2 \ + --gradient_checkpointing=False \ + --group_by_length=False \ + --learning_rate=1e-4 \ + --lr_scheduler_type="cosine" \ + --warmup_steps=100 \ + --weight_decay=0.05 \ + --optim="paged_adamw_32bit" \ + --bf16=True \ + --remove_unused_columns=False \ + --run_name="sft_llama2" \ + --report_to="wandb" + ``` +1. Run the DPO trainer using the model saved by the previous step: + ``` + accelerate launch examples/research_projects/stack_llama_2/scripts/dpo_llama2.py \ + --model_name_or_path="sft/final_checkpoint" \ + --output_dir="dpo" + ``` + + +## Merging the adaptors + +To merge the adaptors into the base model we can use the `merge_peft_adapter.py` helper script that comes with TRL: + +``` +python examples/research_projects/stack_llama/scripts/merge_peft_adapter.py --base_model_name="meta-llama/Llama-2-7b-hf" --adapter_model_name="dpo/final_checkpoint/" --output_name="stack-llama-2" +``` + +which will also push the model to your HuggingFace hub account. + +## Running the model + +We can load the DPO-trained LoRA adaptors which were saved by the DPO training step and load them via: + +```py +from peft import AutoPeftModelForCausalLM + + +model = AutoPeftModelForCausalLM.from_pretrained( + "dpo/final_checkpoint", + low_cpu_mem_usage=True, + torch_dtype=torch.float16, + load_in_4bit=True, +) + +model.generate(...) +``` diff --git a/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py b/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py new file mode 100644 index 0000000000000000000000000000000000000000..43f9d35b3e477d2598a3d3b1d7d861894f601076 --- /dev/null +++ b/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py @@ -0,0 +1,252 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# 0. imports +import os +from dataclasses import dataclass, field +from typing import Optional + +import torch +from accelerate import Accelerator +from datasets import Dataset, load_dataset +from peft import LoraConfig +from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, set_seed + +from trl import DPOConfig, DPOTrainer + + +# Define and parse arguments. +@dataclass +class ScriptArguments: + """ + The arguments for the DPO training script. + """ + + # data parameters + beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"}) + + # training parameters + model_name_or_path: Optional[str] = field( + default="../sft/results/final_checkpoint", + metadata={"help": "the location of the SFT model name or path"}, + ) + learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "optimizer learning rate"}) + lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"}) + warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"}) + weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"}) + optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"}) + + per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "train batch size per device"}) + per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"}) + gradient_accumulation_steps: Optional[int] = field( + default=4, metadata={"help": "the number of gradient accumulation steps"} + ) + gradient_checkpointing: Optional[bool] = field( + default=True, metadata={"help": "whether to use gradient checkpointing"} + ) + + gradient_checkpointing_use_reentrant: Optional[bool] = field( + default=False, metadata={"help": "whether to use reentrant for gradient checkpointing"} + ) + + lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"}) + lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"}) + lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"}) + + max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"}) + max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"}) + max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"}) + logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"}) + save_steps: Optional[int] = field(default=100, metadata={"help": "the saving frequency"}) + eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"}) + + output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"}) + log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"}) + load_in_4bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 4bit"}) + model_dtype: Optional[str] = field( + default="float16", metadata={"help": "model_dtype[float16, bfloat16, float] for loading."} + ) + + # instrumentation + report_to: Optional[str] = field( + default="wandb", + metadata={ + "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,' + '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. ' + 'Use `"all"` to report to all integrations installed, `"none"` for no integrations.' + }, + ) + # debug argument for distributed training + ignore_bias_buffers: Optional[bool] = field( + default=False, + metadata={ + "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See" + "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992" + }, + ) + seed: Optional[int] = field( + default=0, metadata={"help": "Random seed that will be set at the beginning of training."} + ) + + +def get_stack_exchange_paired( + data_dir: str = "data/rl", + cache_dir: Optional[str] = None, + num_proc=24, +) -> Dataset: + """Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format. + + The dataset is converted to a dictionary with the following structure: + { + 'prompt': list[str], + 'chosen': list[str], + 'rejected': list[str], + } + + Prompts are structured as follows: + "Question: " + + "\n\nAnswer: " + """ + dataset = load_dataset( + "lvwerra/stack-exchange-paired", + split="train", + cache_dir=cache_dir, + data_dir=data_dir, + verification_mode="no_checks", + ) + original_columns = dataset.column_names + + def return_prompt_and_responses(samples) -> dict[str, str]: + return { + "prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]], + "chosen": samples["response_j"], + "rejected": samples["response_k"], + } + + return dataset.map( + return_prompt_and_responses, + batched=True, + num_proc=num_proc, + remove_columns=original_columns, + ) + + +if __name__ == "__main__": + parser = HfArgumentParser(ScriptArguments) + script_args = parser.parse_args_into_dataclasses()[0] + + set_seed(script_args.seed) + + # 1. load a pretrained model + torch_dtype = torch.float + if script_args.model_dtype == "float16": + torch_dtype = torch.float16 + elif script_args.model_dtype == "bfloat16": + torch_dtype = torch.bfloat16 + + model = AutoModelForCausalLM.from_pretrained( + script_args.model_name_or_path, + low_cpu_mem_usage=True, + torch_dtype=torch_dtype, + load_in_4bit=script_args.load_in_4bit, + device_map={"": Accelerator().local_process_index}, + ) + model.config.use_cache = False + + if script_args.ignore_bias_buffers: + # torch distributed hack + model._ddp_params_and_buffers_to_ignore = [ + name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool + ] + + tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + tokenizer.pad_token = tokenizer.eos_token + + # 2. Load the Stack-exchange paired dataset + train_dataset = get_stack_exchange_paired(data_dir="data/rl") + train_dataset = train_dataset.filter( + lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length + and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length, + num_proc=script_args.num_proc, + ) + + # 3. Load evaluation dataset + eval_dataset = get_stack_exchange_paired(data_dir="data/evaluation") + eval_dataset = eval_dataset.filter( + lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length + and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length, + num_proc=script_args.num_proc, + ) + + # 4. initialize training arguments: + training_args = DPOConfig( + per_device_train_batch_size=script_args.per_device_train_batch_size, + per_device_eval_batch_size=script_args.per_device_eval_batch_size, + max_steps=script_args.max_steps, + logging_steps=script_args.logging_steps, + save_steps=script_args.save_steps, + gradient_accumulation_steps=script_args.gradient_accumulation_steps, + gradient_checkpointing=script_args.gradient_checkpointing, + learning_rate=script_args.learning_rate, + eval_strategy="steps", + eval_steps=script_args.eval_steps, + output_dir=script_args.output_dir, + report_to=script_args.report_to, + lr_scheduler_type=script_args.lr_scheduler_type, + warmup_steps=script_args.warmup_steps, + optim=script_args.optimizer_type, + bf16=True, + remove_unused_columns=False, + run_name="dpo_llama2", + gradient_checkpointing_kwargs=dict(use_reentrant=script_args.gradient_checkpointing_use_reentrant), + seed=script_args.seed, + ) + + peft_config = LoraConfig( + r=script_args.lora_r, + lora_alpha=script_args.lora_alpha, + lora_dropout=script_args.lora_dropout, + target_modules=[ + "q_proj", + "v_proj", + "k_proj", + "out_proj", + "fc_in", + "fc_out", + "wte", + ], + bias="none", + task_type="CAUSAL_LM", + ) + + # 5. initialize the DPO trainer + dpo_trainer = DPOTrainer( + model, + ref_model=None, + args=training_args, + beta=script_args.beta, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=tokenizer, + peft_config=peft_config, + max_prompt_length=script_args.max_prompt_length, + max_length=script_args.max_length, + ) + + # 6. train + dpo_trainer.train() + dpo_trainer.save_model(script_args.output_dir) + + # 7. save + output_dir = os.path.join(script_args.output_dir, "final_checkpoint") + dpo_trainer.model.save_pretrained(output_dir) diff --git a/examples/research_projects/stack_llama_2/scripts/requirements.txt b/examples/research_projects/stack_llama_2/scripts/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ca124e58df8e4269a4d44d3ceccd0e2a05ea4fae --- /dev/null +++ b/examples/research_projects/stack_llama_2/scripts/requirements.txt @@ -0,0 +1,7 @@ +transformers +trl +peft +accelerate +datasets +bitsandbytes +wandb diff --git a/examples/research_projects/stack_llama_2/scripts/sft_llama2.py b/examples/research_projects/stack_llama_2/scripts/sft_llama2.py new file mode 100644 index 0000000000000000000000000000000000000000..dff5b169e84d18075d12de0f2616e11d2fc1ecaa --- /dev/null +++ b/examples/research_projects/stack_llama_2/scripts/sft_llama2.py @@ -0,0 +1,212 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Fine-Tune Llama2-7b on SE paired dataset +import os +from dataclasses import dataclass, field +from typing import Optional + +import torch +from accelerate import Accelerator +from datasets import load_dataset +from peft import AutoPeftModelForCausalLM, LoraConfig +from tqdm import tqdm +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + BitsAndBytesConfig, + HfArgumentParser, + is_torch_npu_available, + is_torch_xpu_available, + set_seed, +) + +from trl import SFTConfig, SFTTrainer +from trl.trainer import ConstantLengthDataset + + +@dataclass +class ScriptArguments: + model_name: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"}) + dataset_name: Optional[str] = field(default="lvwerra/stack-exchange-paired", metadata={"help": "the dataset name"}) + subset: Optional[str] = field(default="data/finetune", metadata={"help": "the subset to use"}) + split: Optional[str] = field(default="train", metadata={"help": "the split to use"}) + size_valid_set: Optional[int] = field(default=4000, metadata={"help": "the size of the validation set"}) + streaming: Optional[bool] = field(default=True, metadata={"help": "whether to stream the dataset"}) + shuffle_buffer: Optional[int] = field(default=5000, metadata={"help": "the shuffle buffer size"}) + seq_length: Optional[int] = field(default=1024, metadata={"help": "the sequence length"}) + num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"}) + use_bnb: Optional[bool] = field(default=True, metadata={"help": "whether to use BitsAndBytes"}) + + # LoraConfig + lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"}) + lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"}) + lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"}) + + +parser = HfArgumentParser((ScriptArguments, SFTConfig)) +script_args, training_args = parser.parse_args_into_dataclasses() +peft_config = LoraConfig( + r=script_args.lora_r, + lora_alpha=script_args.lora_alpha, + lora_dropout=script_args.lora_dropout, + target_modules=["q_proj", "v_proj"], + bias="none", + task_type="CAUSAL_LM", +) + +if training_args.group_by_length and training_args.packing: + raise ValueError("Cannot use both packing and group by length") + +# `gradient_checkpointing` was True by default until `1f3314`, but it's actually not used. +# `gradient_checkpointing=True` will cause `Variable._execution_engine.run_backward`. +if training_args.gradient_checkpointing: + raise ValueError("gradient_checkpointing not supported") + +set_seed(training_args.seed) + + +def chars_token_ratio(dataset, tokenizer, nb_examples=400): + """ + Estimate the average number of characters per token in the dataset. + """ + total_characters, total_tokens = 0, 0 + for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples): + text = prepare_sample_text(example) + total_characters += len(text) + if tokenizer.is_fast: + total_tokens += len(tokenizer(text).tokens()) + else: + total_tokens += len(tokenizer.tokenize(text)) + + return total_characters / total_tokens + + +def print_trainable_parameters(model): + """ + Prints the number of trainable parameters in the model. + """ + trainable_params = 0 + all_param = 0 + for _, param in model.named_parameters(): + all_param += param.numel() + if param.requires_grad: + trainable_params += param.numel() + print( + f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}" + ) + + +def prepare_sample_text(example): + """Prepare the text from a sample of the dataset.""" + text = f"Question: {example['question']}\n\nAnswer: {example['response_j']}" + return text + + +def create_datasets(tokenizer, args, seed=None): + dataset = load_dataset( + args.dataset_name, + data_dir=args.subset, + split=args.split, + use_auth_token=True, + num_proc=args.num_workers if not args.streaming else None, + streaming=args.streaming, + ) + if args.streaming: + print("Loading the dataset in streaming mode") + valid_data = dataset.take(args.size_valid_set) + train_data = dataset.skip(args.size_valid_set) + train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=seed) + else: + dataset = dataset.train_test_split(test_size=0.005, seed=seed) + train_data = dataset["train"] + valid_data = dataset["test"] + print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}") + + chars_per_token = chars_token_ratio(train_data, tokenizer) + print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}") + + train_dataset = ConstantLengthDataset( + tokenizer, + train_data, + formatting_func=prepare_sample_text, + infinite=True, + seq_length=args.seq_length, + chars_per_token=chars_per_token, + ) + valid_dataset = ConstantLengthDataset( + tokenizer, + valid_data, + formatting_func=prepare_sample_text, + infinite=False, + seq_length=args.seq_length, + chars_per_token=chars_per_token, + ) + return train_dataset, valid_dataset + + +bnb_config = None +if script_args.use_bnb: + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + ) + +base_model = AutoModelForCausalLM.from_pretrained( + script_args.model_name, + quantization_config=bnb_config, + device_map={"": Accelerator().local_process_index}, + trust_remote_code=True, + use_auth_token=True, +) +base_model.config.use_cache = False + + +tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_code=True) +tokenizer.pad_token = tokenizer.eos_token +tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training + +train_dataset, eval_dataset = create_datasets(tokenizer, script_args, seed=training_args.seed) + +trainer = SFTTrainer( + model=base_model, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + peft_config=peft_config, + max_length=None, + formatting_func=prepare_sample_text, + processing_class=tokenizer, + args=training_args, +) +trainer.train() +trainer.save_model(training_args.output_dir) + +output_dir = os.path.join(training_args.output_dir, "final_checkpoint") +trainer.model.save_pretrained(output_dir) + +# Free memory for merging weights +del base_model +if is_torch_xpu_available(): + torch.xpu.empty_cache() +elif is_torch_npu_available(): + torch.npu.empty_cache() +else: + torch.cuda.empty_cache() + +model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map="auto", torch_dtype=torch.bfloat16) +model = model.merge_and_unload() + +output_merged_dir = os.path.join(training_args.output_dir, "final_merged_checkpoint") +model.save_pretrained(output_merged_dir, safe_serialization=True) diff --git a/examples/research_projects/toxicity/README.md b/examples/research_projects/toxicity/README.md new file mode 100644 index 0000000000000000000000000000000000000000..85967ab57ec5eeb10ea9eb6e372a62a0522e4d7e --- /dev/null +++ b/examples/research_projects/toxicity/README.md @@ -0,0 +1,7 @@ +# De-detoxifying language models + +To run this code, do the following: + +```shell +ACCELERATE_LOG_LEVEL=info accelerate launch --config_file {CONFIG} examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py --log_with wandb +``` diff --git a/examples/research_projects/toxicity/scripts/evaluate-toxicity.py b/examples/research_projects/toxicity/scripts/evaluate-toxicity.py new file mode 100644 index 0000000000000000000000000000000000000000..6a1913f3d613d08b241084dae444a5bbc08d1b2e --- /dev/null +++ b/examples/research_projects/toxicity/scripts/evaluate-toxicity.py @@ -0,0 +1,146 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import csv + +import evaluate +import numpy as np +import torch +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer, is_torch_npu_available, is_torch_xpu_available + + +toxicity = evaluate.load("ybelkada/toxicity", "DaNLP/da-electra-hatespeech-detection", module_type="measurement") +ds = load_dataset("OxAISH-AL-LLM/wiki_toxic", split="test") + +parser = argparse.ArgumentParser(description="Evaluate de-toxified models") +parser.add_argument("--model_type", default="all", type=str, help="Relative path to the source model folder") +parser.add_argument("--output_file", default="toxicity.csv", type=str, help="Relative path to the source model folder") +parser.add_argument("--batch_size", default=64, type=int, help="Batch size") +parser.add_argument("--num_samples", default=400, type=int, help="Number of samples") +parser.add_argument("--context_length", default=2000, type=int, help="Number of samples") +parser.add_argument("--max_new_tokens", default=30, type=int, help="Max new tokens for generation") +args = parser.parse_args() + + +if args.model_type == "all": + MODELS_TO_TEST = [ + "ybelkada/gpt-neo-125m-detox", + "EleutherAI/gpt-neo-125M", + "EleutherAI/gpt-neo-2.7B", + "ybelkada/gpt-neo-2.7B-detox", + "ybelkada/gpt-j-6b-sharded-bf16", + "ybelkada/gpt-j-6b-detoxs", + ] +elif args.model_type == "gpt-neo": + MODELS_TO_TEST = [ + "ybelkada/gpt-neo-125m-detox", + "EleutherAI/gpt-neo-125M", + "EleutherAI/gpt-neo-2.7B", + "ybelkada/gpt-neo-2.7B-detox", + ] +elif args.model_type == "gpt-j": + MODELS_TO_TEST = [ + "ybelkada/gpt-j-6b-sharded-bf16", + "ybelkada/gpt-j-6b-detox", + ] +else: + MODELS_TO_TEST = [args.model_type] +NUM_SAMPLES = args.num_samples +BATCH_SIZE = args.batch_size +output_file = args.output_file +max_new_tokens = args.max_new_tokens +context_length = args.context_length +if is_torch_xpu_available(): + device = torch.xpu.current_device() +elif is_torch_npu_available(): + device = torch.npu.current_device() +else: + device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu" + +# consider only toxic prompts +ds = ds.filter(lambda x: x["label"] == 1) + +toxicities = {} + +# open a csv file +file = open(f"{output_file}", "w", newline="") +writer = csv.writer(file) +# add first rows +writer.writerow(["model_id", "mean_toxicity", "std_toxicity"]) + + +for model_id in tqdm(MODELS_TO_TEST): + model = AutoModelForCausalLM.from_pretrained(model_id, device_map={"": device}, torch_dtype=torch.bfloat16) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "left" + input_texts = [] + + for i, example in enumerate(ds): + # set seed + torch.manual_seed(42) + + input_text = example["comment_text"] + input_texts.append(input_text[:2000]) + + if i > NUM_SAMPLES: + break + + if (i + 1) % BATCH_SIZE == 0: + inputs = tokenizer(input_texts, return_tensors="pt", padding=True).to(device) + inputs.input_ids = inputs.input_ids[:context_length] + inputs.attention_mask = inputs.attention_mask[:context_length] + outputs = model.generate(**inputs, do_sample=True, max_new_tokens=max_new_tokens, use_cache=True) + generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) + generated_texts = [ + generated_text.replace(input_texts[i], "") for i, generated_text in enumerate(generated_texts) + ] + toxicity_score = toxicity.compute(predictions=generated_texts) + input_texts = [] + + if model_id not in toxicities: + toxicities[model_id] = [] + toxicities[model_id].extend(toxicity_score["toxicity"]) + + # last batch + inputs = tokenizer(input_texts, return_tensors="pt", padding=True).to(device) + outputs = model.generate(**inputs, do_sample=True, max_new_tokens=30) + generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) + generated_texts = [generated_text.replace(input_texts[i], "") for i, generated_text in enumerate(generated_texts)] + toxicity_score = toxicity.compute(predictions=generated_texts) + toxicities[model_id].extend(toxicity_score["toxicity"]) + + # compute mean & std using np + mean = np.mean(toxicities[model_id]) + std = np.std(toxicities[model_id]) + + # save to file + writer.writerow([model_id, mean, std]) + + # print + print(f"Model: {model_id} - Mean: {mean} - Std: {std}") + + model = None + if is_torch_xpu_available(): + torch.xpu.empty_cache() + elif is_torch_npu_available(): + torch.npu.empty_cache() + else: + torch.cuda.empty_cache() + +# close file +file.close() diff --git a/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py b/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py new file mode 100644 index 0000000000000000000000000000000000000000..edab2a669b5dcd1bf13ae2c0bf0e9fd44dcbb39d --- /dev/null +++ b/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py @@ -0,0 +1,239 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional + +import torch +from datasets import load_dataset +from torch.optim import Adam +from tqdm import tqdm +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + HfArgumentParser, + RobertaForSequenceClassification, + RobertaTokenizer, + set_seed, +) + +from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, create_reference_model +from trl.core import LengthSampler + + +tqdm.pandas() + +######################################################################## +# This is a fully working simple example to use trl with accelerate. +# +# This example fine-tunes a GPTJ model to generate less toxic contents +# by using allenai/real-toxicity-prompts dataset. We use PPO +# (proximal policy optimization) to optimize the model. +# in any of the following settings (with the same script): +# - single CPU or single GPU +# - multi GPUS (using PyTorch distributed mode) +# - multi GPUS (using DeepSpeed ZeRO-Offload stages 1 & 2) +# - fp16 (mixed-precision) or fp32 (normal precision) +# +# To run it in each of these various modes, first initialize the accelerate +# configuration with `accelerate config` +# +######################################################################## + + +# We first define the configuration of the experiment, defining the model, the dataset, +# the training parameters, and the PPO parameters. +# Check the default arguments in the `PPOConfig` class for more details. +# If you want to log with tensorboard, add the kwarg +# `project_kwargs={"logging_dir": PATH_TO_LOGS}` to the PPOConfig. +@dataclass +class ScriptArguments: + """ + The name of the Casual LM model we wish to fine-tune with PPO + """ + + # NOTE: gpt2 models use Conv1D instead of Linear layers which are not yet supported in 8 bit mode + # models like gpt-neo* models are more suitable. + model_name: Optional[str] = field(default="ybelkada/gpt-j-6b-sharded-bf16", metadata={"help": "the model name"}) + log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"}) + learning_rate: Optional[float] = field(default=(1.47e-5) * 2, metadata={"help": "the learning rate"}) + mini_batch_size: Optional[int] = field(default=4, metadata={"help": "the PPO minibatch size"}) + batch_size: Optional[int] = field(default=16, metadata={"help": "the batch size"}) + gradient_accumulation_steps: Optional[int] = field( + default=1, metadata={"help": "the number of gradient accumulation steps"} + ) + model_save_path: Optional[str] = field( + default="./gpt-j-6B-detoxified-long-context-26-shl-1e4-final", + metadata={"help": "the path to save the model"}, + ) + + +parser = HfArgumentParser(ScriptArguments) +script_args = parser.parse_args_into_dataclasses()[0] + +config = PPOConfig( + model_name=script_args.model_name, + learning_rate=script_args.learning_rate, + log_with=script_args.log_with, + ppo_epochs=100, + mini_batch_size=script_args.mini_batch_size, + batch_size=script_args.batch_size, + gradient_accumulation_steps=script_args.gradient_accumulation_steps, +) + + +# Below is an example function to build the dataset. In our case, we use the IMDB dataset +# from the `datasets` library. One should customize this function to train the model on +# its own dataset. +def build_dataset( + config, dataset_name="allenai/real-toxicity-prompts", input_min_text_length=5, input_max_text_length=10 +): + """ + Build dataset for training. This builds the dataset from `load_dataset`, one should + customize this function to train the model on its own dataset. + + Args: + dataset_name (`str`): + The name of the dataset to be loaded. + + Returns: + dataloader (`torch.utils.data.DataLoader`): + The dataloader for the dataset. + """ + tokenizer = AutoTokenizer.from_pretrained(config.model_name) + tokenizer.pad_token = tokenizer.eos_token + + ds = load_dataset(dataset_name, split="train") + + def filter_fn(sample): + toxicity = sample["prompt"]["toxicity"] + return toxicity is not None and toxicity > 0.3 + + ds = ds.filter(filter_fn, batched=False) + + input_size = LengthSampler(input_min_text_length, input_max_text_length) + + def tokenize(sample): + prompt = sample["prompt"]["text"] + continuation = sample["continuation"]["text"] + + sample["input_ids"] = tokenizer.encode(prompt + continuation)[: input_size()] + sample["query"] = tokenizer.decode(sample["input_ids"]) + return sample + + ds = ds.map(tokenize, batched=False) + ds.set_format(type="torch") + + ds = ds.train_test_split(test_size=0.2, shuffle=False)["train"] + + return ds + + +# We retrieve the dataloader by calling the `build_dataset` function. +min_input_length = 30 +max_input_length = 40 +dataset = build_dataset(config, input_min_text_length=min_input_length, input_max_text_length=max_input_length) + + +def collator(data): + return {key: [d[key] for d in data] for key in data[0]} + + +# set seed before initializing value head for deterministic eval +set_seed(config.seed) + +# Now let's build the model, the reference model, and the tokenizer. We first load the model +# in bfloat16 to save memory using `transformers`. +model = AutoModelForCausalLM.from_pretrained(config.model_name, torch_dtype=torch.bfloat16) +# And then we pass the loaded model to `AutoModelForCausalLMWithValueHead`. +model = AutoModelForCausalLMWithValueHead.from_pretrained(model) + +# We create a reference model by sharing 20 layers +ref_model = create_reference_model(model, num_shared_layers=20) + +# We make sure to use `Adam` optimizer on the model parameters that require gradients. +optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate) + +# GPT-2 / GPT-J tokenizer has a pad token, but it is not eos_token by default. We need to set it to eos_token. +# only for this model. +tokenizer = AutoTokenizer.from_pretrained(config.model_name) +tokenizer.pad_token = tokenizer.eos_token + +# We then build the PPOTrainer, passing the model, the reference model, the tokenizer +ppo_trainer = PPOTrainer( + config, + model, + ref_model=ref_model, + tokenizer=tokenizer, + dataset=dataset, + data_collator=collator, + optimizer=optimizer, +) + +# We then build the reward pipeline, we will use the toxicity model to compute the reward. +# We first load the toxicity model and tokenizer. +toxicity_model_id = "facebook/roberta-hate-speech-dynabench-r4-target" +toxicity_tokenizer = RobertaTokenizer.from_pretrained(toxicity_model_id) +# We load the toxicity model in fp16 to save memory. +toxicity_model = RobertaForSequenceClassification.from_pretrained(toxicity_model_id, torch_dtype=torch.float16).to( + ppo_trainer.accelerator.device +) + + +# We then define the arguments to pass to the `generate` function. These arguments +# are passed to the `generate` function of the PPOTrainer, which is a wrapper around +# the `generate` function of the trained model. +generation_kwargs = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.eos_token_id, +} +output_min_length = 20 +output_max_length = 30 +output_length_sampler = LengthSampler(output_min_length, output_max_length) + +model_save_path = script_args.model_save_path + +for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)): + query_tensors = batch["input_ids"] + + # Get response from the policy model + response_tensors = [] + for query in query_tensors: + gen_len = output_length_sampler() + generation_kwargs["max_new_tokens"] = gen_len + response = ppo_trainer.generate(query, **generation_kwargs) + response_tensors.append(response.squeeze()[-gen_len:]) + batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors] + + # Compute sentiment score + texts = batch["response"] + toxicity_inputs = toxicity_tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to( + ppo_trainer.accelerator.device + ) + logits = toxicity_model(**toxicity_inputs).logits.float() + toxicity_labels = (logits[:, 0]).tolist() + + rewards = [torch.tensor(output) for output in toxicity_labels] + + # Run PPO step + stats = ppo_trainer.step(query_tensors, response_tensors, rewards) + ppo_trainer.log_stats(stats, batch, rewards) + + # Save model every 100 epochs + if epoch % 100 == 0: + if ppo_trainer.accelerator.is_main_process: + ppo_trainer.save_pretrained(model_save_path) diff --git a/examples/scripts/alignprop.py b/examples/scripts/alignprop.py new file mode 100644 index 0000000000000000000000000000000000000000..d4772942a9b206dc9b331e23c47d7e602b494454 --- /dev/null +++ b/examples/scripts/alignprop.py @@ -0,0 +1,154 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Total Batch size = 128 = 4 (num_gpus) * 8 (per_device_batch) * 4 (accumulation steps) +Feel free to reduce batch size or increasing truncated_rand_backprop_min to a higher value to reduce memory usage. + +CUDA_VISIBLE_DEVICES=0,1,2,3 python examples/scripts/alignprop.py \ + --num_epochs=20 \ + --train_gradient_accumulation_steps=4 \ + --sample_num_steps=50 \ + --train_batch_size=8 \ + --tracker_project_name="stable_diffusion_training" \ + --log_with="wandb" + +""" + +from dataclasses import dataclass, field + +import numpy as np +from transformers import HfArgumentParser + +from trl import AlignPropConfig, AlignPropTrainer, DefaultDDPOStableDiffusionPipeline +from trl.models.auxiliary_modules import aesthetic_scorer + + +@dataclass +class ScriptArguments: + r""" + Arguments for the script. + + Args: + pretrained_model (`str`, *optional*, defaults to `"runwayml/stable-diffusion-v1-5"`): + Pretrained model to use. + pretrained_revision (`str`, *optional*, defaults to `"main"`): + Pretrained model revision to use. + hf_hub_model_id (`str`, *optional*, defaults to `"alignprop-finetuned-stable-diffusion"`): + HuggingFace repo to save model weights to. + hf_hub_aesthetic_model_id (`str`, *optional*, defaults to `"trl-lib/ddpo-aesthetic-predictor"`): + Hugging Face model ID for aesthetic scorer model weights. + hf_hub_aesthetic_model_filename (`str`, *optional*, defaults to `"aesthetic-model.pth"`): + Hugging Face model filename for aesthetic scorer model weights. + use_lora (`bool`, *optional*, defaults to `True`): + Whether to use LoRA. + """ + + pretrained_model: str = field( + default="runwayml/stable-diffusion-v1-5", metadata={"help": "Pretrained model to use."} + ) + pretrained_revision: str = field(default="main", metadata={"help": "Pretrained model revision to use."}) + hf_hub_model_id: str = field( + default="alignprop-finetuned-stable-diffusion", metadata={"help": "HuggingFace repo to save model weights to."} + ) + hf_hub_aesthetic_model_id: str = field( + default="trl-lib/ddpo-aesthetic-predictor", + metadata={"help": "Hugging Face model ID for aesthetic scorer model weights."}, + ) + hf_hub_aesthetic_model_filename: str = field( + default="aesthetic-model.pth", + metadata={"help": "Hugging Face model filename for aesthetic scorer model weights."}, + ) + use_lora: bool = field(default=True, metadata={"help": "Whether to use LoRA."}) + + +# list of example prompts to feed stable diffusion +animals = [ + "cat", + "dog", + "horse", + "monkey", + "rabbit", + "zebra", + "spider", + "bird", + "sheep", + "deer", + "cow", + "goat", + "lion", + "frog", + "chicken", + "duck", + "goose", + "bee", + "pig", + "turkey", + "fly", + "llama", + "camel", + "bat", + "gorilla", + "hedgehog", + "kangaroo", +] + + +def prompt_fn(): + return np.random.choice(animals), {} + + +def image_outputs_logger(image_pair_data, global_step, accelerate_logger): + # For the sake of this example, we will only log the last batch of images + # and associated data + result = {} + images, prompts, _ = [image_pair_data["images"], image_pair_data["prompts"], image_pair_data["rewards"]] + for i, image in enumerate(images[:4]): + prompt = prompts[i] + result[f"{prompt}"] = image.unsqueeze(0).float() + accelerate_logger.log_images( + result, + step=global_step, + ) + + +if __name__ == "__main__": + parser = HfArgumentParser((ScriptArguments, AlignPropConfig)) + script_args, training_args = parser.parse_args_into_dataclasses() + training_args.project_kwargs = { + "logging_dir": "./logs", + "automatic_checkpoint_naming": True, + "total_limit": 5, + "project_dir": "./save", + } + + pipeline = DefaultDDPOStableDiffusionPipeline( + script_args.pretrained_model, + pretrained_model_revision=script_args.pretrained_revision, + use_lora=script_args.use_lora, + ) + trainer = AlignPropTrainer( + training_args, + aesthetic_scorer(script_args.hf_hub_aesthetic_model_id, script_args.hf_hub_aesthetic_model_filename), + prompt_fn, + pipeline, + image_samples_hook=image_outputs_logger, + ) + + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/examples/scripts/bco.py b/examples/scripts/bco.py new file mode 100644 index 0000000000000000000000000000000000000000..0cffabca8791575b45b2f13155f114a1711ef342 --- /dev/null +++ b/examples/scripts/bco.py @@ -0,0 +1,167 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run the BCO training script with the commands below. In general, the optimal configuration for BCO will be similar to that of KTO. + +# Full training: +python examples/scripts/bco.py \ + --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \ + --trust_remote_code \ + --dataset_name trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 32 \ + --num_train_epochs 1 \ + --learning_rate 1e-6 \ + --gradient_checkpointing \ + --gradient_accumulation_steps 1 \ + --logging_steps 0.01 \ + --eval_steps 0.2 \ + --save_strategy no \ + --output_dir=bco-aligned-model \ + --logging_first_step \ + --max_length 2048 \ + --max_prompt_length 1536 \ + --max_completion_length 1024 \ + --no_remove_unused_columns \ + --warmup_ratio 0.1 \ + --bf16 \ + --report_to wandb + +# QLoRA: +python examples/scripts/bco.py \ + --model_name_or_path=nnheui/stablelm-2-1_6b-sft-full \ + --per_device_train_batch_size 16 \ + --per_device_eval_batch_size 32 \ + --num_train_epochs 1 \ + --learning_rate 1e-6 \ + --gradient_checkpointing \ + --gradient_accumulation_steps 1 \ + --logging_steps 0.01 \ + --eval_steps 0.2 \ + --save_strategy no \ + --output_dir=bco-aligned-model-lora \ + --logging_first_step \ + --warmup_ratio 0.1 \ + --report_to wandb \ + --max_length 2048 \ + --max_prompt_length 1536 \ + --max_completion_length 1024 \ + --no_remove_unused_columns \ + --warmup_ratio 0.1 \ + --bf16 \ + --use_peft \ + --load_in_4bit \ + --lora_target_modules=all-linear \ + --lora_r=16 \ + --lora_alpha=16 +""" + +from functools import partial + +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from datasets import load_dataset +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, PreTrainedModel + +from trl import BCOConfig, BCOTrainer, ModelConfig, ScriptArguments, get_peft_config, setup_chat_format + + +def embed_prompt(input_ids: torch.LongTensor, attention_mask: torch.LongTensor, model: PreTrainedModel): + """ + Borrowed from https://huggingface.co/nomic-ai/nomic-embed-text-v1.5#transformers + """ + + def mean_pooling(model_output, attention_mask): + token_embeddings = model_output[0] + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) + + with torch.no_grad(): + model_output = model(input_ids=input_ids, attention_mask=attention_mask) + embeddings = mean_pooling(model_output, attention_mask) + + matryoshka_dim = 512 + # normalize embeddings + embeddings = F.normalize(embeddings, p=2, dim=1) + embeddings = F.layer_norm(embeddings, normalized_shape=(embeddings.shape[1],)) + embeddings = embeddings[:, :matryoshka_dim] + + return embeddings + + +if __name__ == "__main__": + parser = HfArgumentParser((ScriptArguments, BCOConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_into_dataclasses() + + training_args.gradient_checkpointing_kwargs = {"use_reentrant": True} + + # Load a pretrained model + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + ref_model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # If we are aligning a base model, we use ChatML as the default template + if tokenizer.chat_template is None: + model, tokenizer = setup_chat_format(model, tokenizer) + + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + + accelerator = Accelerator() + embedding_model = AutoModel.from_pretrained( + "nomic-ai/nomic-embed-text-v1.5", + trust_remote_code=model_args.trust_remote_code, + safe_serialization=True, + torch_dtype=torch.bfloat16, + device_map="auto", + ) + embedding_model = accelerator.prepare_model(embedding_model) + embedding_tokenizer = AutoTokenizer.from_pretrained( + "bert-base-uncased", trust_remote_code=model_args.trust_remote_code + ) + embedding_func = partial( + embed_prompt, + model=embedding_model, + ) + + # Initialize the BCO trainer + trainer = BCOTrainer( + model, + ref_model, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=tokenizer, + peft_config=get_peft_config(model_args), + embedding_func=embedding_func, + embedding_tokenizer=embedding_tokenizer, + ) + + # Train and push the model to the Hub + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/examples/scripts/cpo.py b/examples/scripts/cpo.py new file mode 100644 index 0000000000000000000000000000000000000000..a90d875972492396888d66711230b267ad83ff68 --- /dev/null +++ b/examples/scripts/cpo.py @@ -0,0 +1,106 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run the CPO training script with the following command with some example arguments. +In general, the optimal configuration for CPO will be similar to that of DPO: + +# regular: +python examples/scripts/cpo.py \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --model_name_or_path=gpt2 \ + --per_device_train_batch_size 4 \ + --max_steps 1000 \ + --learning_rate 8e-6 \ + --gradient_accumulation_steps 1 \ + --logging_steps 10 \ + --eval_steps 500 \ + --output_dir="gpt2-aligned-cpo" \ + --warmup_steps 150 \ + --report_to wandb \ + --bf16 \ + --logging_first_step \ + --no_remove_unused_columns + +# peft: +python examples/scripts/cpo.py \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --model_name_or_path=gpt2 \ + --per_device_train_batch_size 4 \ + --max_steps 1000 \ + --learning_rate 8e-5 \ + --gradient_accumulation_steps 1 \ + --logging_steps 10 \ + --eval_steps 500 \ + --output_dir="gpt2-lora-aligned-cpo" \ + --optim rmsprop \ + --warmup_steps 150 \ + --report_to wandb \ + --bf16 \ + --logging_first_step \ + --no_remove_unused_columns \ + --use_peft \ + --lora_r=16 \ + --lora_alpha=16 +""" + +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser + +from trl import CPOConfig, CPOTrainer, ModelConfig, ScriptArguments, get_peft_config +from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE + + +if __name__ == "__main__": + parser = HfArgumentParser((ScriptArguments, CPOConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_into_dataclasses() + + ################ + # Model & Tokenizer + ################ + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + ################ + # Dataset + ################ + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + if tokenizer.chat_template is None: + tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE + + ################ + # Training + ################ + trainer = CPOTrainer( + model, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=tokenizer, + peft_config=get_peft_config(model_args), + ) + + # train and save the model + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/examples/scripts/ddpo.py b/examples/scripts/ddpo.py new file mode 100644 index 0000000000000000000000000000000000000000..fec079c95a10472809d045b3f786945241727a63 --- /dev/null +++ b/examples/scripts/ddpo.py @@ -0,0 +1,234 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +python examples/scripts/ddpo.py \ + --num_epochs=200 \ + --train_gradient_accumulation_steps=1 \ + --sample_num_steps=50 \ + --sample_batch_size=6 \ + --train_batch_size=3 \ + --sample_num_batches_per_epoch=4 \ + --per_prompt_stat_tracking=True \ + --per_prompt_stat_tracking_buffer_size=32 \ + --tracker_project_name="stable_diffusion_training" \ + --log_with="wandb" +""" + +import os +from dataclasses import dataclass, field + +import numpy as np +import torch +import torch.nn as nn +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import EntryNotFoundError +from transformers import CLIPModel, CLIPProcessor, HfArgumentParser, is_torch_npu_available, is_torch_xpu_available + +from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline + + +@dataclass +class ScriptArguments: + r""" + Arguments for the script. + + Args: + pretrained_model (`str`, *optional*, defaults to `"runwayml/stable-diffusion-v1-5"`): + Pretrained model to use. + pretrained_revision (`str`, *optional*, defaults to `"main"`): + Pretrained model revision to use. + hf_hub_model_id (`str`, *optional*, defaults to `"ddpo-finetuned-stable-diffusion"`): + HuggingFace repo to save model weights to. + hf_hub_aesthetic_model_id (`str`, *optional*, defaults to `"trl-lib/ddpo-aesthetic-predictor"`): + Hugging Face model ID for aesthetic scorer model weights. + hf_hub_aesthetic_model_filename (`str`, *optional*, defaults to `"aesthetic-model.pth"`): + Hugging Face model filename for aesthetic scorer model weights. + use_lora (`bool`, *optional*, defaults to `True`): + Whether to use LoRA. + """ + + pretrained_model: str = field( + default="runwayml/stable-diffusion-v1-5", metadata={"help": "Pretrained model to use."} + ) + pretrained_revision: str = field(default="main", metadata={"help": "Pretrained model revision to use."}) + hf_hub_model_id: str = field( + default="ddpo-finetuned-stable-diffusion", metadata={"help": "HuggingFace repo to save model weights to."} + ) + hf_hub_aesthetic_model_id: str = field( + default="trl-lib/ddpo-aesthetic-predictor", + metadata={"help": "Hugging Face model ID for aesthetic scorer model weights."}, + ) + hf_hub_aesthetic_model_filename: str = field( + default="aesthetic-model.pth", + metadata={"help": "Hugging Face model filename for aesthetic scorer model weights."}, + ) + use_lora: bool = field(default=True, metadata={"help": "Whether to use LoRA."}) + + +class MLP(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(768, 1024), + nn.Dropout(0.2), + nn.Linear(1024, 128), + nn.Dropout(0.2), + nn.Linear(128, 64), + nn.Dropout(0.1), + nn.Linear(64, 16), + nn.Linear(16, 1), + ) + + @torch.no_grad() + def forward(self, embed): + return self.layers(embed) + + +class AestheticScorer(torch.nn.Module): + """ + This model attempts to predict the aesthetic score of an image. The aesthetic score + is a numerical approximation of how much a specific image is liked by humans on average. + This is from https://github.com/christophschuhmann/improved-aesthetic-predictor + """ + + def __init__(self, *, dtype, model_id, model_filename): + super().__init__() + self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14") + self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14") + self.mlp = MLP() + try: + cached_path = hf_hub_download(model_id, model_filename) + except EntryNotFoundError: + cached_path = os.path.join(model_id, model_filename) + state_dict = torch.load(cached_path, map_location=torch.device("cpu"), weights_only=True) + self.mlp.load_state_dict(state_dict) + self.dtype = dtype + self.eval() + + @torch.no_grad() + def __call__(self, images): + device = next(self.parameters()).device + inputs = self.processor(images=images, return_tensors="pt") + inputs = {k: v.to(self.dtype).to(device) for k, v in inputs.items()} + embed = self.clip.get_image_features(**inputs) + # normalize embedding + embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True) + return self.mlp(embed).squeeze(1) + + +def aesthetic_scorer(hub_model_id, model_filename): + scorer = AestheticScorer( + model_id=hub_model_id, + model_filename=model_filename, + dtype=torch.float32, + ) + if is_torch_npu_available(): + scorer = scorer.npu() + elif is_torch_xpu_available(): + scorer = scorer.xpu() + else: + scorer = scorer.cuda() + + def _fn(images, prompts, metadata): + images = (images * 255).round().clamp(0, 255).to(torch.uint8) + scores = scorer(images) + return scores, {} + + return _fn + + +# list of example prompts to feed stable diffusion +animals = [ + "cat", + "dog", + "horse", + "monkey", + "rabbit", + "zebra", + "spider", + "bird", + "sheep", + "deer", + "cow", + "goat", + "lion", + "frog", + "chicken", + "duck", + "goose", + "bee", + "pig", + "turkey", + "fly", + "llama", + "camel", + "bat", + "gorilla", + "hedgehog", + "kangaroo", +] + + +def prompt_fn(): + return np.random.choice(animals), {} + + +def image_outputs_logger(image_data, global_step, accelerate_logger): + # For the sake of this example, we will only log the last batch of images + # and associated data + result = {} + images, prompts, _, rewards, _ = image_data[-1] + + for i, image in enumerate(images): + prompt = prompts[i] + reward = rewards[i].item() + result[f"{prompt:.25} | {reward:.2f}"] = image.unsqueeze(0).float() + + accelerate_logger.log_images( + result, + step=global_step, + ) + + +if __name__ == "__main__": + parser = HfArgumentParser((ScriptArguments, DDPOConfig)) + script_args, training_args = parser.parse_args_into_dataclasses() + training_args.project_kwargs = { + "logging_dir": "./logs", + "automatic_checkpoint_naming": True, + "total_limit": 5, + "project_dir": "./save", + } + + pipeline = DefaultDDPOStableDiffusionPipeline( + script_args.pretrained_model, + pretrained_model_revision=script_args.pretrained_revision, + use_lora=script_args.use_lora, + ) + + trainer = DDPOTrainer( + training_args, + aesthetic_scorer(script_args.hf_hub_aesthetic_model_id, script_args.hf_hub_aesthetic_model_filename), + prompt_fn, + pipeline, + image_samples_hook=image_outputs_logger, + ) + + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/examples/scripts/dpo.py b/examples/scripts/dpo.py new file mode 100644 index 0000000000000000000000000000000000000000..0cadeb9a74ec5896264d5b87478a7b88a74ae217 --- /dev/null +++ b/examples/scripts/dpo.py @@ -0,0 +1,17 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +############################################################################################### +# This file has been moved to https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py # +############################################################################################### diff --git a/examples/scripts/dpo_online.py b/examples/scripts/dpo_online.py new file mode 100644 index 0000000000000000000000000000000000000000..33d5cb18242bba79071626f356f94a0a52b89795 --- /dev/null +++ b/examples/scripts/dpo_online.py @@ -0,0 +1,148 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: + +python examples/scripts/dpo_online.py \ + --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \ + --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \ + --dataset_name trl-lib/tldr \ + --learning_rate 5.0e-7 \ + --output_dir pythia-1b-tldr-online-dpo \ + --per_device_train_batch_size 8 \ + --gradient_accumulation_steps 16 \ + --warmup_ratio 0.1 \ + --missing_eos_penalty 1.0 + +With LoRA: +python examples/scripts/dpo_online.py \ + --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \ + --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \ + --dataset_name trl-lib/tldr \ + --learning_rate 5.0e-6 \ + --output_dir pythia-1b-tldr-online-dpo \ + --per_device_train_batch_size 16 \ + --gradient_accumulation_steps 8 \ + --warmup_ratio 0.1 \ + --missing_eos_penalty 1.0 \ + --use_peft +""" + +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig + +from trl import ( + HfPairwiseJudge, + LogCompletionsCallback, + ModelConfig, + OnlineDPOConfig, + OnlineDPOTrainer, + OpenAIPairwiseJudge, + PairRMJudge, + ScriptArguments, + TrlParser, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) +from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE + + +JUDGES = {"pair_rm": PairRMJudge, "openai": OpenAIPairwiseJudge, "hf": HfPairwiseJudge} + +if __name__ == "__main__": + parser = TrlParser((ScriptArguments, OnlineDPOConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_and_config() + training_args.gradient_checkpointing_kwargs = {"use_reentrant": True} + + torch_dtype = ( + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) + ) + quantization_config = get_quantization_config(model_args) + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + torch_dtype=torch_dtype, + use_cache=False if training_args.gradient_checkpointing else True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs + ) + + if training_args.reward_model_path is not None: + reward_model = AutoModelForSequenceClassification.from_pretrained( + training_args.reward_model_path, + num_labels=1, + trust_remote_code=model_args.trust_remote_code, + **model_kwargs, + ) + reward_tokenizer = AutoTokenizer.from_pretrained( + training_args.reward_model_path, + trust_remote_code=model_args.trust_remote_code, + truncation=True, + truncation_side="left", # since we judge the completion, truncating left is more appropriate + ) + else: + reward_model = None + reward_tokenizer = None + + if training_args.judge is not None: + judge_cls = JUDGES[training_args.judge] + judge = judge_cls() + else: + judge = None + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + padding_side="left", + trust_remote_code=model_args.trust_remote_code, + **model_kwargs, + ) + if tokenizer.chat_template is None: + tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + + trainer = OnlineDPOTrainer( + model=model, + reward_model=reward_model, + judge=judge, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=tokenizer, + reward_processing_class=reward_tokenizer, + peft_config=get_peft_config(model_args), + ) + + if training_args.eval_strategy != "no": + generation_config = GenerationConfig( + max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature + ) + completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8) + trainer.add_callback(completions_callback) + + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/examples/scripts/dpo_vlm.py b/examples/scripts/dpo_vlm.py new file mode 100644 index 0000000000000000000000000000000000000000..ef2db5359a6ae8d29387351f6008aa4f1da2f915 --- /dev/null +++ b/examples/scripts/dpo_vlm.py @@ -0,0 +1,152 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Without dataset streaming: + +``` +accelerate launch examples/scripts/dpo_vlm.py \ + --dataset_name HuggingFaceH4/rlaif-v_formatted \ + --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 32 \ + --dataset_num_proc 32 \ + --output_dir dpo_idefics_rlaif-v \ + --bf16 \ + --torch_dtype bfloat16 \ + --gradient_checkpointing \ + --use_peft \ + --lora_target_modules=all-linear \ + --report_to wandb +``` + +With dataset streaming: + +``` +accelerate launch examples/scripts/dpo_vlm.py \ + --dataset_name HuggingFaceH4/rlaif-v_formatted \ + --dataset_streaming \ + --model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \ + --per_device_train_batch_size 2 \ + --max_steps 100 \ + --gradient_accumulation_steps 32 \ + --dataset_num_proc 32 \ + --output_dir dpo_idefics_rlaif-v \ + --bf16 \ + --torch_dtype bfloat16 \ + --gradient_checkpointing \ + --use_peft \ + --lora_target_modules=all-linear \ + --report_to wandb +``` +""" + +import torch +from datasets import load_dataset +from transformers import AutoModelForVision2Seq, AutoProcessor + +from trl import ( + DPOConfig, + DPOTrainer, + ModelConfig, + ScriptArguments, + TrlParser, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) + + +if __name__ == "__main__": + parser = TrlParser((ScriptArguments, DPOConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_and_config() + + ################ + # Model & Tokenizer + ################ + torch_dtype = ( + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) + ) + quantization_config = get_quantization_config(model_args) + + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + torch_dtype=torch_dtype, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + model = AutoModelForVision2Seq.from_pretrained( + model_args.model_name_or_path, + trust_remote_code=model_args.trust_remote_code, + **model_kwargs, + ) + peft_config = get_peft_config(model_args) + if peft_config is None: + ref_model = AutoModelForVision2Seq.from_pretrained( + model_args.model_name_or_path, + trust_remote_code=model_args.trust_remote_code, + **model_kwargs, + ) + else: + ref_model = None + processor = AutoProcessor.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, do_image_splitting=False + ) + tokenizer = processor.tokenizer + + # Set up the chat template + if model.config.model_type == "idefics2": + pass # the processor already has a valid chat template + elif model.config.model_type == "paligemma": + processor.chat_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}<|im_start|>{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] if item['type'] == 'text' %}{{ item['text'] }}<|im_end|>{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}""" + elif model.config.model_type == "llava": + processor.chat_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}""" + + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + if script_args.ignore_bias_buffers: + # torch distributed hack + model._ddp_params_and_buffers_to_ignore = [ + name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool + ] + + ################ + # Dataset + ################ + dataset = load_dataset( + script_args.dataset_name, + name=script_args.dataset_config, + streaming=script_args.dataset_streaming, + ) + + ################ + # Training + ################ + trainer = DPOTrainer( + model, + ref_model, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=processor, + peft_config=peft_config, + ) + + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/examples/scripts/evals/judge_tldr.py b/examples/scripts/evals/judge_tldr.py new file mode 100644 index 0000000000000000000000000000000000000000..b26841ae46dbed8734b9eeec25cf92d35e5905b9 --- /dev/null +++ b/examples/scripts/evals/judge_tldr.py @@ -0,0 +1,102 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional + +from datasets import load_dataset +from transformers import HfArgumentParser +from vllm import LLM, SamplingParams + +from trl import HfPairwiseJudge, OpenAIPairwiseJudge + + +""" +Examples: + +python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/rloo_tldr --num_examples 1000 +Model win rate: 31.40% + +python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/rloo_tldr --judge_model gpt-3.5-turbo-0125 --num_examples 1000 +Model win rate: 51.60% + +python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/rloo_tldr --judge_model gpt-4o-mini --num_examples 1000 +Model win rate: 51.20% + +python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/ppo_tldr --num_examples 1000 +Model win rate: 46.30% + +python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/ppo_tldr --judge_model gpt-3.5-turbo-0125 --num_examples 1000 +Model win rate: 52.50% + +python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/ppo_tldr --judge_model gpt-4o-mini --num_examples 1000 +Model win rate: 63.00% +""" + + +@dataclass +class ScriptArguments: + r""" + Arguments for the script. + + Args: + model_name_or_path (`str`): + Model name or path to the model to evaluate. + judge_model (`str`, *optional*, defaults to `"meta-llama/Meta-Llama-3-70B-Instruct"`): + Model name or path to the model to use as a judge. E.g., 'gpt-3.5-turbo-0125' or + 'meta-llama/Meta-Llama-3-70B-Instruct'. + num_examples (`int` or `None`, *optional*, defaults to `None`): + Number of examples to evaluate. + """ + + model_name_or_path: str = field(metadata={"help": "Model name or path to the model to evaluate."}) + judge_model: str = field( + default="meta-llama/Meta-Llama-3-70B-Instruct", + metadata={ + "help": "Model name or path to the model to use as a judge. E.g., 'gpt-3.5-turbo-0125' or " + "'meta-llama/Meta-Llama-3-70B-Instruct'." + }, + ) + num_examples: Optional[int] = field(default=None, metadata={"help": "Number of examples to evaluate."}) + + +# Parse the arguments +parser = HfArgumentParser(ScriptArguments) +script_args = parser.parse_args_into_dataclasses()[0] + +# Load the dataset +dataset = load_dataset("trl-lib/tldr", split="validation") +if script_args.num_examples is not None: + dataset = dataset.select(range(script_args.num_examples)) + +# Extract the prompts and reference completions +prompts = dataset["prompt"] +reference_completions = dataset["completion"] + +# Generate the model completions +sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=200) # very generous max token length +llm = LLM(model=script_args.model_name_or_path, tensor_parallel_size=1) +outputs = llm.generate(prompts, sampling_params) +model_completions = [output.outputs[0].text.strip() for output in outputs] + +# Judge the outputs +if "gpt" in script_args.judge_model: + judge = OpenAIPairwiseJudge(script_args.judge_model) +else: + judge = HfPairwiseJudge(script_args.judge_model) + +completions = [[c0, c1] for c0, c1 in zip(reference_completions, model_completions)] +best_idxs = judge.judge(prompts, completions) +model_win_rate = best_idxs.count(1) / len(best_idxs) +print(f"Model win rate: {model_win_rate * 100:.2f}%") diff --git a/examples/scripts/gkd.py b/examples/scripts/gkd.py new file mode 100644 index 0000000000000000000000000000000000000000..10b12c324f479b580cb4be055e071c238b090e75 --- /dev/null +++ b/examples/scripts/gkd.py @@ -0,0 +1,133 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +# Full training: +python examples/scripts/gkd.py \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --teacher_model_name_or_path Qwen/Qwen2-1.5B-Instruct \ + --dataset_name trl-lib/chatbot_arena_completions \ + --learning_rate 2e-5 \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 8 \ + --output_dir gkd-model \ + --logging_steps 10 \ + --num_train_epochs 1 \ + --push_to_hub \ + --gradient_checkpointing + +# LoRA: +python examples/scripts/gkd.py \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --teacher_model_name_or_path Qwen/Qwen2-1.5B-Instruct \ + --dataset_name trl-lib/chatbot_arena_completions \ + --learning_rate 2e-4 \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 8 \ + --output_dir gkd-model \ + --logging_steps 10 \ + --num_train_epochs 1 \ + --push_to_hub \ + --gradient_checkpointing \ + --use_peft \ + --lora_r 64 \ + --lora_alpha 16 +""" + +from datasets import load_dataset +from transformers import AutoTokenizer, GenerationConfig + +from trl import ( + GKDConfig, + GKDTrainer, + LogCompletionsCallback, + ModelConfig, + ScriptArguments, + TrlParser, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) + + +if __name__ == "__main__": + parser = TrlParser((ScriptArguments, GKDConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_and_config() + + ################ + # Model & Tokenizer + ################ + quantization_config = get_quantization_config(model_args) + model_kwargs = dict( + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + torch_dtype=model_args.torch_dtype, + use_cache=False if training_args.gradient_checkpointing else True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + training_args.model_init_kwargs = model_kwargs + + teacher_model_kwargs = dict( + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + torch_dtype=model_args.torch_dtype, + use_cache=True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + training_args.teacher_model_init_kwargs = teacher_model_kwargs + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + padding_side="left", + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + ################ + # Dataset + ################ + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + + ################ + # Training + ################ + trainer = GKDTrainer( + model=model_args.model_name_or_path, + teacher_model=training_args.teacher_model_name_or_path, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=tokenizer, + peft_config=get_peft_config(model_args), + ) + + if training_args.eval_strategy != "no": + generation_config = GenerationConfig( + max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature + ) + completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8) + trainer.add_callback(completions_callback) + + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/examples/scripts/kto.py b/examples/scripts/kto.py new file mode 100644 index 0000000000000000000000000000000000000000..b88ac3ffeada5b8df357872161c9ca0832501356 --- /dev/null +++ b/examples/scripts/kto.py @@ -0,0 +1,113 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run the KTO training script with the commands below. In general, the optimal configuration for KTO will be similar to that of DPO. + +# Full training: +python trl/scripts/kto.py \ + --dataset_name trl-lib/kto-mix-14k \ + --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ + --per_device_train_batch_size 16 \ + --num_train_epochs 1 \ + --learning_rate 5e-7 \ + --lr_scheduler_type=cosine \ + --gradient_accumulation_steps 1 \ + --logging_steps 10 \ + --eval_steps 500 \ + --output_dir=kto-aligned-model \ + --warmup_ratio 0.1 \ + --report_to wandb \ + --bf16 \ + --logging_first_step + +# QLoRA: +python trl/scripts/kto.py \ + --dataset_name trl-lib/kto-mix-14k \ + --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ + --per_device_train_batch_size 8 \ + --num_train_epochs 1 \ + --learning_rate 5e-7 \ + --lr_scheduler_type=cosine \ + --gradient_accumulation_steps 1 \ + --logging_steps 10 \ + --eval_steps 500 \ + --output_dir=kto-aligned-model-lora \ + --warmup_ratio 0.1 \ + --report_to wandb \ + --bf16 \ + --logging_first_step \ + --use_peft \ + --load_in_4bit \ + --lora_target_modules=all-linear \ + --lora_r=16 \ + --lora_alpha=16 +""" + +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser + +from trl import ( + KTOConfig, + KTOTrainer, + ModelConfig, + ScriptArguments, + get_peft_config, + setup_chat_format, +) + + +if __name__ == "__main__": + parser = HfArgumentParser((ScriptArguments, KTOConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_into_dataclasses() + + # Load a pretrained model + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + ref_model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # If we are aligning a base model, we use ChatML as the default template + if tokenizer.chat_template is None: + model, tokenizer = setup_chat_format(model, tokenizer) + + # Load the dataset + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + + # Initialize the KTO trainer + trainer = KTOTrainer( + model, + ref_model, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=tokenizer, + peft_config=get_peft_config(model_args), + ) + + # Train and push the model to the Hub + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/examples/scripts/nash_md.py b/examples/scripts/nash_md.py new file mode 100644 index 0000000000000000000000000000000000000000..ff8ea4621043a8efa2328ee2a37cc8c373d75f2d --- /dev/null +++ b/examples/scripts/nash_md.py @@ -0,0 +1,145 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: + +python examples/scripts/nash_md.py \ + --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \ + --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \ + --dataset_name trl-lib/tldr \ + --learning_rate 5.0e-7 \ + --output_dir pythia-1b-tldr-nash-md \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 32 \ + --num_train_epochs 3 \ + --max_new_tokens 64 \ + --warmup_ratio 0.1 \ + --missing_eos_penalty 1.0 \ + --push_to_hub + + +accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \ + examples/scripts/nash_md.py \ + --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \ + --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \ + --dataset_name trl-lib/tldr \ + --learning_rate 5.0e-7 \ + --output_dir pythia-1b-tldr-nash-md \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 32 \ + --num_train_epochs 3 \ + --max_new_tokens 64 \ + --warmup_ratio 0.1 \ + --missing_eos_penalty 1.0 \ + --push_to_hub +""" + +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig + +from trl import ( + HfPairwiseJudge, + LogCompletionsCallback, + ModelConfig, + NashMDConfig, + NashMDTrainer, + OpenAIPairwiseJudge, + PairRMJudge, + ScriptArguments, + TrlParser, + get_kbit_device_map, + get_quantization_config, +) +from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE + + +JUDGES = {"pair_rm": PairRMJudge, "openai": OpenAIPairwiseJudge, "hf": HfPairwiseJudge} + +if __name__ == "__main__": + parser = TrlParser((ScriptArguments, NashMDConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_and_config() + training_args.gradient_checkpointing_kwargs = {"use_reentrant": True} + + torch_dtype = ( + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) + ) + quantization_config = get_quantization_config(model_args) + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + torch_dtype=torch_dtype, + use_cache=False if training_args.gradient_checkpointing else True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs + ) + ref_model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs + ) + + if training_args.reward_model_path is not None: + reward_model = AutoModelForSequenceClassification.from_pretrained( + training_args.reward_model_path, + num_labels=1, + trust_remote_code=model_args.trust_remote_code, + **model_kwargs, + ) + else: + reward_model = None + + if training_args.judge is not None: + judge_cls = JUDGES[training_args.judge] + judge = judge_cls() + else: + judge = None + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + if tokenizer.chat_template is None: + tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE + + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + + trainer = NashMDTrainer( + model=model, + ref_model=ref_model, + reward_model=reward_model, + judge=judge, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=tokenizer, + ) + + if training_args.eval_strategy != "no": + generation_config = GenerationConfig( + max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature + ) + completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8) + trainer.add_callback(completions_callback) + + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/examples/scripts/orpo.py b/examples/scripts/orpo.py new file mode 100644 index 0000000000000000000000000000000000000000..10a56787480a4ca566b76c5b341c3ed1f446a0d8 --- /dev/null +++ b/examples/scripts/orpo.py @@ -0,0 +1,106 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run the ORPO training script with the following command with some example arguments. +In general, the optimal configuration for ORPO will be similar to that of DPO without the need for a reference model: + +# regular: +python examples/scripts/orpo.py \ + --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style \ + --model_name_or_path=gpt2 \ + --per_device_train_batch_size 4 \ + --max_steps 1000 \ + --learning_rate 8e-6 \ + --gradient_accumulation_steps 1 \ + --logging_steps 10 \ + --eval_steps 500 \ + --output_dir="gpt2-aligned-orpo" \ + --warmup_steps 150 \ + --report_to wandb \ + --bf16 \ + --logging_first_step \ + --no_remove_unused_columns + +# peft: +python examples/scripts/orpo.py \ + --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style \ + --model_name_or_path=gpt2 \ + --per_device_train_batch_size 4 \ + --max_steps 1000 \ + --learning_rate 8e-5 \ + --gradient_accumulation_steps 1 \ + --logging_steps 10 \ + --eval_steps 500 \ + --output_dir="gpt2-lora-aligned-orpo" \ + --optim rmsprop \ + --warmup_steps 150 \ + --report_to wandb \ + --bf16 \ + --logging_first_step \ + --no_remove_unused_columns \ + --use_peft \ + --lora_r=16 \ + --lora_alpha=16 +""" + +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser + +from trl import ModelConfig, ORPOConfig, ORPOTrainer, ScriptArguments, get_peft_config +from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE + + +if __name__ == "__main__": + parser = HfArgumentParser((ScriptArguments, ORPOConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_into_dataclasses() + + ################ + # Model & Tokenizer + ################ + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + ################ + # Dataset + ################ + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + if tokenizer.chat_template is None: + tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE + + ################ + # Training + ################ + trainer = ORPOTrainer( + model, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=tokenizer, + peft_config=get_peft_config(model_args), + ) + + # train and save the model + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/examples/scripts/ppo/ppo.py b/examples/scripts/ppo/ppo.py new file mode 100644 index 0000000000000000000000000000000000000000..4374f0baa3b62e66837508e13bd4397ea8930f6f --- /dev/null +++ b/examples/scripts/ppo/ppo.py @@ -0,0 +1,170 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil + +import torch +from accelerate import PartialState +from datasets import load_dataset +from transformers import ( + AutoModelForCausalLM, + AutoModelForSequenceClassification, + AutoTokenizer, + HfArgumentParser, +) + +from trl import ( + ModelConfig, + PPOConfig, + PPOTrainer, + ScriptArguments, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) +from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE + + +""" +python -i examples/scripts/ppo/ppo.py \ + --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ + --dataset_train_split descriptiveness \ + --learning_rate 3e-6 \ + --output_dir models/minimal/ppo \ + --per_device_train_batch_size 64 \ + --gradient_accumulation_steps 1 \ + --total_episodes 10000 \ + --model_name_or_path EleutherAI/pythia-1b-deduped \ + --missing_eos_penalty 1.0 + +accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ + examples/scripts/ppo/ppo.py \ + --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ + --dataset_train_split descriptiveness \ + --output_dir models/minimal/ppo \ + --num_ppo_epochs 1 \ + --num_mini_batches 1 \ + --learning_rate 3e-6 \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 16 \ + --total_episodes 10000 \ + --model_name_or_path EleutherAI/pythia-1b-deduped \ + --sft_model_path EleutherAI/pythia-1b-deduped \ + --reward_model_path EleutherAI/pythia-1b-deduped \ + --local_rollout_forward_batch_size 1 \ + --missing_eos_penalty 1.0 +""" + + +if __name__ == "__main__": + parser = HfArgumentParser((ScriptArguments, PPOConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_into_dataclasses() + # remove output_dir if exists + shutil.rmtree(training_args.output_dir, ignore_errors=True) + + ################ + # Model & Tokenizer + ################ + torch_dtype = ( + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) + ) + quantization_config = get_quantization_config(model_args) + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + torch_dtype=torch_dtype, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code + ) + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + if tokenizer.chat_template is None: + tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE + value_model = AutoModelForSequenceClassification.from_pretrained( + training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1 + ) + reward_model = AutoModelForSequenceClassification.from_pretrained( + training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1 + ) + policy = AutoModelForCausalLM.from_pretrained( + training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code + ) + + peft_config = get_peft_config(model_args) + if peft_config is None: + ref_policy = AutoModelForCausalLM.from_pretrained( + training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code + ) + else: + ref_policy = None + + ################ + # Dataset + ################ + dataset = load_dataset( + script_args.dataset_name, name=script_args.dataset_config, split=script_args.dataset_train_split + ) + eval_samples = 100 + train_dataset = dataset.select(range(len(dataset) - eval_samples)) + eval_dataset = dataset.select(range(len(dataset) - eval_samples, len(dataset))) + dataset_text_field = "prompt" + + def prepare_dataset(dataset, tokenizer): + """pre-tokenize the dataset before training; only collate during training""" + + def tokenize(element): + outputs = tokenizer( + element[dataset_text_field], + padding=False, + ) + return {"input_ids": outputs["input_ids"]} + + return dataset.map( + tokenize, + batched=True, + remove_columns=dataset.column_names, + num_proc=training_args.dataset_num_proc, + ) + + # Compute that only on the main process for faster data processing. + # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().local_main_process_first(): + train_dataset = prepare_dataset(train_dataset, tokenizer) + eval_dataset = prepare_dataset(eval_dataset, tokenizer) + + ################ + # Training + ################ + trainer = PPOTrainer( + args=training_args, + processing_class=tokenizer, + model=policy, + ref_model=ref_policy, + reward_model=reward_model, + value_model=value_model, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + peft_config=peft_config, + ) + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + + trainer.generate_completions() diff --git a/examples/scripts/ppo/ppo_tldr.py b/examples/scripts/ppo/ppo_tldr.py new file mode 100644 index 0000000000000000000000000000000000000000..03cc2e20a9ef5d573f0718dc02380c73ec17c0b6 --- /dev/null +++ b/examples/scripts/ppo/ppo_tldr.py @@ -0,0 +1,179 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil + +import torch +from accelerate import PartialState +from datasets import load_dataset +from transformers import ( + AutoModelForCausalLM, + AutoModelForSequenceClassification, + AutoTokenizer, + HfArgumentParser, +) + +from trl import ( + ModelConfig, + PPOConfig, + PPOTrainer, + ScriptArguments, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) +from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE + + +""" +python examples/scripts/ppo/ppo_tldr.py \ + --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ + --dataset_test_split validation \ + --learning_rate 3e-6 \ + --output_dir models/minimal/ppo_tldr \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 64 \ + --total_episodes 30000 \ + --model_name_or_path EleutherAI/pythia-1b-deduped \ + --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ + --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ + --missing_eos_penalty 1.0 \ + --stop_token eos \ + --response_length 53 \ + --eval_strategy steps \ + --eval_steps 100 + +accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \ + examples/scripts/ppo/ppo_tldr.py \ + --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ + --dataset_test_split validation \ + --output_dir models/minimal/ppo_tldr \ + --learning_rate 3e-6 \ + --per_device_train_batch_size 16 \ + --gradient_accumulation_steps 4 \ + --total_episodes 1000000 \ + --model_name_or_path EleutherAI/pythia-1b-deduped \ + --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ + --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ + --local_rollout_forward_batch_size 16 \ + --missing_eos_penalty 1.0 \ + --stop_token eos \ + --eval_strategy steps \ + --eval_steps 100 +""" + + +if __name__ == "__main__": + parser = HfArgumentParser((ScriptArguments, PPOConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_into_dataclasses() + # remove output_dir if exists + shutil.rmtree(training_args.output_dir, ignore_errors=True) + + ################ + # Model & Tokenizer + ################ + torch_dtype = ( + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) + ) + quantization_config = get_quantization_config(model_args) + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + torch_dtype=torch_dtype, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code + ) + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + if tokenizer.chat_template is None: + tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE + value_model = AutoModelForSequenceClassification.from_pretrained( + training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1 + ) + reward_model = AutoModelForSequenceClassification.from_pretrained( + training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1 + ) + policy = AutoModelForCausalLM.from_pretrained( + training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code + ) + + peft_config = get_peft_config(model_args) + if peft_config is None: + ref_policy = AutoModelForCausalLM.from_pretrained( + training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code + ) + else: + ref_policy = None + + ################ + # Dataset + ################ + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + train_dataset = dataset[script_args.dataset_train_split] + eval_dataset = dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None + + def prepare_dataset(dataset, tokenizer): + """pre-tokenize the dataset before training; only collate during training""" + + def tokenize(element): + input_ids = tokenizer.apply_chat_template( + element["messages"][:1], + padding=False, + add_generation_prompt=True, + ) + return {"input_ids": input_ids, "lengths": len(input_ids)} + + return dataset.map( + tokenize, + remove_columns=dataset.column_names, + num_proc=training_args.dataset_num_proc, + ) + + # Compute that only on the main process for faster data processing. + # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().local_main_process_first(): + train_dataset = prepare_dataset(train_dataset, tokenizer) + if eval_dataset is not None: + eval_dataset = prepare_dataset(eval_dataset, tokenizer) + # filtering + train_dataset = train_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=training_args.dataset_num_proc) + if eval_dataset is not None: + eval_dataset = eval_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=training_args.dataset_num_proc) + + assert train_dataset[0]["input_ids"][-1] != tokenizer.eos_token_id, "The last token should not be an EOS token" + ################ + # Training + ################ + trainer = PPOTrainer( + args=training_args, + processing_class=tokenizer, + model=policy, + ref_model=ref_policy, + reward_model=reward_model, + value_model=value_model, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + peft_config=peft_config, + ) + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + + trainer.generate_completions() diff --git a/examples/scripts/prm.py b/examples/scripts/prm.py new file mode 100644 index 0000000000000000000000000000000000000000..86ce316c11a66949d50e80a961b16504aa5c5272 --- /dev/null +++ b/examples/scripts/prm.py @@ -0,0 +1,130 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Full training: +python examples/scripts/prm.py \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --dataset_name trl-lib/prm800k \ + --output_dir Qwen2-0.5B-Reward \ + --per_device_train_batch_size 8 \ + --num_train_epochs 1 \ + --gradient_checkpointing True \ + --learning_rate 1.0e-5 \ + --logging_steps 25 \ + --eval_strategy steps \ + --eval_steps 50 + +LoRA: +python examples/scripts/prm.py \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --dataset_name trl-lib/prm800k \ + --output_dir Qwen2-0.5B-Reward-LoRA \ + --per_device_train_batch_size 8 \ + --num_train_epochs 1 \ + --gradient_checkpointing True \ + --learning_rate 1.0e-4 \ + --logging_steps 25 \ + --eval_strategy steps \ + --eval_steps 50 + --use_peft \ + --lora_r 32 \ + --lora_alpha 16 +""" + +import warnings + +import torch +from datasets import load_dataset +from transformers import AutoModelForTokenClassification, AutoTokenizer, HfArgumentParser + +from trl import ( + ModelConfig, + PRMConfig, + PRMTrainer, + ScriptArguments, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) + + +if __name__ == "__main__": + parser = HfArgumentParser((ScriptArguments, PRMConfig, ModelConfig)) + script_args, training_args, model_config = parser.parse_args_into_dataclasses() + training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) + + ################ + # Model & Tokenizer + ################ + torch_dtype = ( + model_config.torch_dtype + if model_config.torch_dtype in ["auto", None] + else getattr(torch, model_config.torch_dtype) + ) + quantization_config = get_quantization_config(model_config) + model_kwargs = dict( + revision=model_config.model_revision, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + use_cache=False if training_args.gradient_checkpointing else True, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True + ) + model = AutoModelForTokenClassification.from_pretrained( + model_config.model_name_or_path, num_labels=2, trust_remote_code=model_config.trust_remote_code, **model_kwargs + ) + # Align padding tokens between tokenizer and model + model.config.pad_token_id = tokenizer.pad_token_id + + if model_config.use_peft and model_config.lora_task_type != "TOKEN_CLS": + warnings.warn( + "You are using a `task_type` that is different than `TOKEN_CLS` for PEFT. This will lead to silent bugs" + " Make sure to pass --lora_task_type TOKEN_CLS when using this script with PEFT.", + UserWarning, + ) + + ############## + # Load dataset + ############## + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + + dataset = dataset.filter(lambda x: len(x["completions"]) > 0) + + ########## + # Training + ########## + trainer = PRMTrainer( + model=model, + processing_class=tokenizer, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split], + peft_config=get_peft_config(model_config), + ) + trainer.train() + + ############################ + # Save model and push to Hub + ############################ + trainer.save_model(training_args.output_dir) + metrics = trainer.evaluate() + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/examples/scripts/reward_modeling.py b/examples/scripts/reward_modeling.py new file mode 100644 index 0000000000000000000000000000000000000000..14ca8367fe5fa918da9fb7d43647edaec243de78 --- /dev/null +++ b/examples/scripts/reward_modeling.py @@ -0,0 +1,136 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Full training: +python examples/scripts/reward_modeling.py \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --output_dir Qwen2-0.5B-Reward \ + --per_device_train_batch_size 8 \ + --num_train_epochs 1 \ + --gradient_checkpointing True \ + --learning_rate 1.0e-5 \ + --logging_steps 25 \ + --eval_strategy steps \ + --eval_steps 50 \ + --max_length 2048 + +LoRA: +python examples/scripts/reward_modeling.py \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --output_dir Qwen2-0.5B-Reward-LoRA \ + --per_device_train_batch_size 8 \ + --num_train_epochs 1 \ + --gradient_checkpointing True \ + --learning_rate 1.0e-4 \ + --logging_steps 25 \ + --eval_strategy steps \ + --eval_steps 50 \ + --max_length 2048 \ + --use_peft \ + --lora_r 32 \ + --lora_alpha 16 +""" + +import warnings + +import torch +from datasets import load_dataset +from transformers import AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser + +from trl import ( + ModelConfig, + RewardConfig, + RewardTrainer, + ScriptArguments, + get_kbit_device_map, + get_peft_config, + get_quantization_config, + setup_chat_format, +) + + +if __name__ == "__main__": + parser = HfArgumentParser((ScriptArguments, RewardConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_into_dataclasses() + training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) + + ################ + # Model & Tokenizer + ################ + torch_dtype = ( + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) + ) + quantization_config = get_quantization_config(model_args) + model_kwargs = dict( + revision=model_args.model_revision, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + use_cache=False if training_args.gradient_checkpointing else True, + torch_dtype=torch_dtype, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True + ) + model = AutoModelForSequenceClassification.from_pretrained( + model_args.model_name_or_path, num_labels=1, trust_remote_code=model_args.trust_remote_code, **model_kwargs + ) + # Align padding tokens between tokenizer and model + model.config.pad_token_id = tokenizer.pad_token_id + + # If post-training a base model, use ChatML as the default template + if tokenizer.chat_template is None: + model, tokenizer = setup_chat_format(model, tokenizer) + + if model_args.use_peft and model_args.lora_task_type != "SEQ_CLS": + warnings.warn( + "You are using a `task_type` that is different than `SEQ_CLS` for PEFT. This will lead to silent bugs" + " Make sure to pass --lora_task_type SEQ_CLS when using this script with PEFT.", + UserWarning, + ) + + ############## + # Load dataset + ############## + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + + ########## + # Training + ########## + trainer = RewardTrainer( + model=model, + processing_class=tokenizer, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + peft_config=get_peft_config(model_args), + ) + trainer.train() + + ############################ + # Save model and push to Hub + ############################ + trainer.save_model(training_args.output_dir) + + if training_args.eval_strategy != "no": + metrics = trainer.evaluate() + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/examples/scripts/rloo/rloo.py b/examples/scripts/rloo/rloo.py new file mode 100644 index 0000000000000000000000000000000000000000..fdb5527cff8981f18ce15893342d78e8fe4f6013 --- /dev/null +++ b/examples/scripts/rloo/rloo.py @@ -0,0 +1,141 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil + +from accelerate import PartialState +from datasets import load_dataset +from transformers import ( + AutoModelForCausalLM, + AutoModelForSequenceClassification, + AutoTokenizer, + HfArgumentParser, +) + +from trl import ModelConfig, RLOOConfig, RLOOTrainer, ScriptArguments +from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE + + +""" +python -i examples/scripts/rloo/rloo.py \ + --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ + --dataset_train_split descriptiveness \ + --learning_rate 3e-6 \ + --num_ppo_epochs 1 \ + --num_mini_batches 1 \ + --output_dir models/minimal/ppo \ + --per_device_train_batch_size 64 \ + --gradient_accumulation_steps 1 \ + --total_episodes 10000 \ + --model_name_or_path EleutherAI/pythia-1b-deduped \ + --missing_eos_penalty 1.0 + +accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ + examples/scripts/rloo/rloo.py \ + --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \ + --dataset_train_split descriptiveness \ + --output_dir models/minimal/rloo \ + --rloo_k 2 \ + --num_ppo_epochs 1 \ + --num_mini_batches 1 \ + --learning_rate 3e-6 \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 16 \ + --total_episodes 10000 \ + --model_name_or_path EleutherAI/pythia-1b-deduped \ + --sft_model_path EleutherAI/pythia-1b-deduped \ + --reward_model_path EleutherAI/pythia-1b-deduped \ + --local_rollout_forward_batch_size 1 \ + --missing_eos_penalty 1.0 +""" + + +if __name__ == "__main__": + parser = HfArgumentParser((ScriptArguments, RLOOConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_into_dataclasses() + # remove output_dir if exists + shutil.rmtree(training_args.output_dir, ignore_errors=True) + + ################ + # Model & Tokenizer + ################ + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code + ) + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + if tokenizer.chat_template is None: + tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE + reward_model = AutoModelForSequenceClassification.from_pretrained( + training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1 + ) + ref_policy = AutoModelForCausalLM.from_pretrained( + training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code + ) + policy = AutoModelForCausalLM.from_pretrained( + training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code + ) + ################ + # Dataset + ################ + dataset = load_dataset( + script_args.dataset_name, name=script_args.dataset_config, split=script_args.dataset_train_split + ) + eval_samples = 100 + train_dataset = dataset.select(range(len(dataset) - eval_samples)) + eval_dataset = dataset.select(range(len(dataset) - eval_samples, len(dataset))) + dataset_text_field = "prompt" + + def prepare_dataset(dataset, tokenizer): + """pre-tokenize the dataset before training; only collate during training""" + + def tokenize(element): + outputs = tokenizer( + element[dataset_text_field], + padding=False, + ) + return {"input_ids": outputs["input_ids"]} + + return dataset.map( + tokenize, + batched=True, + remove_columns=dataset.column_names, + num_proc=training_args.dataset_num_proc, + ) + + # Compute that only on the main process for faster data processing. + # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().local_main_process_first(): + train_dataset = prepare_dataset(train_dataset, tokenizer) + eval_dataset = prepare_dataset(eval_dataset, tokenizer) + + ################ + # Training + ################ + trainer = RLOOTrainer( + config=training_args, + processing_class=tokenizer, + policy=policy, + ref_policy=ref_policy, + reward_model=reward_model, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + ) + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + + trainer.generate_completions() diff --git a/examples/scripts/rloo/rloo_tldr.py b/examples/scripts/rloo/rloo_tldr.py new file mode 100644 index 0000000000000000000000000000000000000000..0c3cbabfe56b3a588b8705cbef1f78077b2fdd5b --- /dev/null +++ b/examples/scripts/rloo/rloo_tldr.py @@ -0,0 +1,143 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil + +from accelerate import PartialState +from datasets import load_dataset +from transformers import ( + AutoModelForCausalLM, + AutoModelForSequenceClassification, + AutoTokenizer, + HfArgumentParser, +) + +from trl import ModelConfig, RLOOConfig, RLOOTrainer, ScriptArguments +from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE + + +""" +python examples/scripts/rloo/rloo_tldr.py \ + --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ + --dataset_test_split validation \ + --learning_rate 3e-6 \ + --output_dir models/minimal/ppo \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 64 \ + --total_episodes 30000 \ + --model_name_or_path EleutherAI/pythia-1b-deduped \ + --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ + --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ + --missing_eos_penalty 1.0 \ + --stop_token eos \ + --response_length 53 + +accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \ + examples/scripts/rloo/rloo_tldr.py \ + --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \ + --dataset_test_split validation \ + --output_dir models/minimal/rloo_tldr \ + --num_ppo_epochs 1 \ + --num_mini_batches 1 \ + --learning_rate 3e-6 \ + --per_device_train_batch_size 16 \ + --gradient_accumulation_steps 4 \ + --total_episodes 1000000 \ + --model_name_or_path EleutherAI/pythia-1b-deduped \ + --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ + --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ + --local_rollout_forward_batch_size 16 \ + --missing_eos_penalty 1.0 \ + --stop_token eos +""" + + +if __name__ == "__main__": + parser = HfArgumentParser((ScriptArguments, RLOOConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_into_dataclasses() + # remove output_dir if exists + shutil.rmtree(training_args.output_dir, ignore_errors=True) + + ################ + # Model & Tokenizer + ################ + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code + ) + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + if tokenizer.chat_template is None: + tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE + reward_model = AutoModelForSequenceClassification.from_pretrained( + training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1 + ) + ref_policy = AutoModelForCausalLM.from_pretrained( + training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code + ) + policy = AutoModelForCausalLM.from_pretrained( + training_args.sft_model_path, trust_remote_code=model_args.trust_remote_code + ) + ################ + # Dataset + ################ + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + train_dataset = dataset[script_args.dataset_train_split] + eval_dataset = dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None + + def prepare_dataset(dataset, tokenizer): + """pre-tokenize the dataset before training; only collate during training""" + + def tokenize(element): + input_ids = tokenizer.apply_chat_template( + element["messages"][:1], + padding=False, + add_generation_prompt=True, + ) + return {"input_ids": input_ids, "lengths": len(input_ids)} + + return dataset.map( + tokenize, + remove_columns=dataset.column_names, + num_proc=training_args.dataset_num_proc, + ) + + # Compute that only on the main process for faster data processing. + # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().local_main_process_first(): + train_dataset = prepare_dataset(train_dataset, tokenizer) + eval_dataset = prepare_dataset(eval_dataset, tokenizer) + # filtering + train_dataset = train_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=training_args.dataset_num_proc) + eval_dataset = eval_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=training_args.dataset_num_proc) + + assert train_dataset[0]["input_ids"][-1] != tokenizer.eos_token_id, "The last token should not be an EOS token" + ################ + # Training + ################ + trainer = RLOOTrainer( + config=training_args, + processing_class=tokenizer, + policy=policy, + ref_policy=ref_policy, + reward_model=reward_model, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + ) + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + + trainer.generate_completions() diff --git a/examples/scripts/sft.py b/examples/scripts/sft.py new file mode 100644 index 0000000000000000000000000000000000000000..3a6e26c4b439a0b03154de6a250760e67289eeca --- /dev/null +++ b/examples/scripts/sft.py @@ -0,0 +1,17 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +############################################################################################### +# This file has been moved to https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py # +############################################################################################### diff --git a/examples/scripts/sft_gemma3.py b/examples/scripts/sft_gemma3.py new file mode 100644 index 0000000000000000000000000000000000000000..c19e2095ab95043a9c3584679819cbfe801bde6d --- /dev/null +++ b/examples/scripts/sft_gemma3.py @@ -0,0 +1,62 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Train Gemma-3 on the Codeforces COTS dataset. + +accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml examples/scripts/sft_gemma3.py +""" + +from datasets import load_dataset +from transformers import AutoModelForImageTextToText + +from trl import SFTConfig, SFTTrainer + + +def main(): + # Load dataset + train_dataset = load_dataset("open-r1/codeforces-cots", split="train") + train_dataset = train_dataset.remove_columns("prompt") + + # Load model + model_id = "google/gemma-3-12b-it" + model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation="eager") + + # Train model + training_args = SFTConfig( + output_dir=f"{model_id}-codeforces-SFT", + logging_steps=10, + bf16=True, + use_liger_kernel=True, + gradient_checkpointing=True, + gradient_checkpointing_kwargs={"use_reentrant": False}, + max_length=8192, + per_device_train_batch_size=1, + gradient_accumulation_steps=8, + dataset_num_proc=32, + num_train_epochs=1, + ) + trainer = SFTTrainer( + args=training_args, + model=model, + train_dataset=train_dataset, + ) + trainer.train() + + # Push to hub + trainer.push_to_hub(dataset_name="open-r1/codeforces-cots") + + +if __name__ == "__main__": + main() diff --git a/examples/scripts/sft_video_llm.py b/examples/scripts/sft_video_llm.py new file mode 100644 index 0000000000000000000000000000000000000000..ca46ae64b4e87f86c9d1b17b8f29ff1a7a6a821d --- /dev/null +++ b/examples/scripts/sft_video_llm.py @@ -0,0 +1,253 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example usage: +accelerate launch \ + --config_file=deepspeed_zero2.yaml \ + sft_video_llm.py \ + --dataset_name=mfarre/simplevideoshorts \ + --video_cache_dir="/optional/path/to/cache/" \ + --model_name_or_path=Qwen/Qwen2-VL-7B-Instruct \ + --per_device_train_batch_size=1 \ + --output_dir=video-llm-output \ + --bf16=True \ + --tf32=True \ + --gradient_accumulation_steps=4 \ + --num_train_epochs=4 \ + --optim="adamw_torch_fused" \ + --logging_steps=1 \ + --log_level="debug" \ + --log_level_replica="debug" \ + --save_strategy="steps" \ + --save_steps=300 \ + --learning_rate=8e-5 \ + --max_grad_norm=0.3 \ + --warmup_ratio=0.1 \ + --lr_scheduler_type="cosine" \ + --report_to="wandb" \ + --push_to_hub=False \ + --torch_dtype=bfloat16 \ + --gradient_checkpointing=True +""" + +import json +import os +import random +from dataclasses import dataclass, field +from typing import Any + +import requests +import torch +import wandb +from datasets import load_dataset +from peft import LoraConfig +from qwen_vl_utils import process_vision_info +from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig, Qwen2VLProcessor + +from trl import ModelConfig, ScriptArguments, SFTConfig, SFTTrainer, TrlParser, get_kbit_device_map + + +def download_video(url: str, cache_dir: str) -> str: + """Download video if not already present locally.""" + os.makedirs(cache_dir, exist_ok=True) # Create cache dir if it doesn't exist + filename = url.split("/")[-1] + local_path = os.path.join(cache_dir, filename) + + if os.path.exists(local_path): + return local_path + + try: + with requests.get(url, stream=True) as r: + r.raise_for_status() + with open(local_path, "wb") as f: + for chunk in r.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + return local_path + except requests.RequestException as e: + raise Exception(f"Failed to download video: {e}") from e + + +def prepare_dataset(example: dict[str, Any], cache_dir: str) -> dict[str, list[dict[str, Any]]]: + """Prepare dataset example for training.""" + video_url = example["video_url"] + timecoded_cc = example["timecoded_cc"] + qa_pairs = json.loads(example["qa"]) + + system_message = "You are an expert in movie narrative analysis." + base_prompt = f"""Analyze the video and consider the following timecoded subtitles: + +{timecoded_cc} + +Based on this information, please answer the following questions:""" + + selected_qa = random.sample(qa_pairs, 1)[0] + + messages = [ + {"role": "system", "content": [{"type": "text", "text": system_message}]}, + { + "role": "user", + "content": [ + {"type": "video", "video": download_video(video_url, cache_dir), "max_pixels": 360 * 420, "fps": 1.0}, + {"type": "text", "text": f"{base_prompt}\n\nQuestion: {selected_qa['question']}"}, + ], + }, + {"role": "assistant", "content": [{"type": "text", "text": selected_qa["answer"]}]}, + ] + + return {"messages": messages} + + +def collate_fn(examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]: + """Collate batch of examples for training.""" + texts = [] + video_inputs = [] + + for i, example in enumerate(examples): + try: + video_path = next( + content["video"] + for message in example["messages"] + for content in message["content"] + if content.get("type") == "video" + ) + print(f"Processing video: {os.path.basename(video_path)}") + + texts.append(processor.apply_chat_template(example["messages"], tokenize=False)) + video_input = process_vision_info(example["messages"])[1][0] + video_inputs.append(video_input) + except Exception as e: + raise ValueError(f"Failed to process example {i}: {e}") from e + + inputs = processor(text=texts, videos=video_inputs, return_tensors="pt", padding=True) + + labels = inputs["input_ids"].clone() + labels[labels == processor.tokenizer.pad_token_id] = -100 + + # Handle visual tokens based on processor type + visual_tokens = ( + [151652, 151653, 151656] + if isinstance(processor, Qwen2VLProcessor) + else [processor.tokenizer.convert_tokens_to_ids(processor.image_token)] + ) + + for visual_token_id in visual_tokens: + labels[labels == visual_token_id] = -100 + + inputs["labels"] = labels + return inputs + + +@dataclass +class CustomScriptArguments(ScriptArguments): + r""" + Arguments for the script. + + Args: + video_cache_dir (`str`, *optional*, defaults to `"/tmp/videos/"`): + Video cache directory. + """ + + video_cache_dir: str = field(default="/tmp/videos/", metadata={"help": "Video cache directory."}) + + +if __name__ == "__main__": + # Parse arguments + parser = TrlParser((CustomScriptArguments, SFTConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_and_config() + + # Configure training args + training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) + training_args.remove_unused_columns = False + training_args.dataset_kwargs = {"skip_prepare_dataset": True} + + # Load dataset + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config, split="train") + + # Setup model + torch_dtype = ( + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) + ) + + # Quantization configuration for 4-bit training + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + ) + + # Model initialization + model_kwargs = dict( + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + torch_dtype=torch_dtype, + device_map=get_kbit_device_map(), + quantization_config=bnb_config, + ) + + model = AutoModelForVision2Seq.from_pretrained(model_args.model_name_or_path, **model_kwargs) + + peft_config = LoraConfig( + task_type="CAUSAL_LM", + r=16, + lora_alpha=16, + lora_dropout=0.1, + bias="none", + target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + ) + + # Configure model modules for gradients + if training_args.gradient_checkpointing: + model.gradient_checkpointing_enable() + model.config.use_reentrant = False + model.enable_input_require_grads() + + processor = AutoProcessor.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + + # Prepare dataset + prepared_dataset = [prepare_dataset(example, script_args.video_cache_dir) for example in dataset] + + # Initialize wandb if specified + if training_args.report_to == "wandb": + wandb.init(project="video-llm-training") + + # Initialize trainer + trainer = SFTTrainer( + model=model, + args=training_args, + train_dataset=prepared_dataset, + data_collator=collate_fn, + peft_config=peft_config, + tokenizer=processor.tokenizer, + ) + + # Train model + trainer.train() + + # Save final model + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + if trainer.accelerator.is_main_process: + processor.push_to_hub(training_args.hub_model_id) + + # Cleanup + del model + del trainer + torch.cuda.empty_cache() + wandb.finish() diff --git a/examples/scripts/sft_vlm.py b/examples/scripts/sft_vlm.py new file mode 100644 index 0000000000000000000000000000000000000000..ba154c653408a4d3c06a62c857a8d65994f6d5f0 --- /dev/null +++ b/examples/scripts/sft_vlm.py @@ -0,0 +1,132 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +pip install pillow + +# Tested on 8x H100 GPUs +accelerate launch + --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \ + examples/scripts/sft_vlm.py \ + --dataset_name HuggingFaceH4/llava-instruct-mix-vsft \ + --model_name_or_path llava-hf/llava-1.5-7b-hf \ + --per_device_train_batch_size 8 \ + --gradient_accumulation_steps 8 \ + --output_dir sft-llava-1.5-7b-hf \ + --bf16 \ + --torch_dtype bfloat16 \ + --gradient_checkpointing + +For LLaVA-NeXT, use: (requires transformers>=4.45) + --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf + +For meta-llama/Llama-3.2-11B-Vision-Instruct, use: (requires transformers>=4.45.1) + --model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct +""" + +import torch +from datasets import load_dataset +from transformers import AutoModelForVision2Seq, AutoProcessor, LlavaForConditionalGeneration + +from trl import ( + ModelConfig, + ScriptArguments, + SFTConfig, + SFTTrainer, + TrlParser, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) + + +if __name__ == "__main__": + parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_and_config() + training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) + training_args.remove_unused_columns = False + training_args.dataset_kwargs = {"skip_prepare_dataset": True} + + ################ + # Model, Tokenizer & Processor + ################ + torch_dtype = ( + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) + ) + quantization_config = get_quantization_config(model_args) + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + torch_dtype=torch_dtype, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + processor = AutoProcessor.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + + model = AutoModelForVision2Seq.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs + ) + + ################ + # Create a data collator to encode text and image pairs + ################ + def collate_fn(examples): + # Get the texts and images, and apply the chat template + texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples] + images = [example["images"] for example in examples] + if isinstance(model, LlavaForConditionalGeneration): + # LLava1.5 does not support multiple images + images = [image[0] for image in images] + + # Tokenize the texts and process the images + batch = processor(text=texts, images=images, return_tensors="pt", padding=True) + + # The labels are the input_ids, and we mask the padding tokens in the loss computation + labels = batch["input_ids"].clone() + labels[labels == processor.tokenizer.pad_token_id] = -100 # + # Ignore the image token index in the loss computation (model specific) + image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token) + labels[labels == image_token_id] = -100 + batch["labels"] = labels + + return batch + + ################ + # Dataset + ################ + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + + ################ + # Training + ################ + trainer = SFTTrainer( + model=model, + args=training_args, + data_collator=collate_fn, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=processor.tokenizer, + peft_config=get_peft_config(model_args), + ) + + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + if trainer.accelerator.is_main_process: + processor.push_to_hub(training_args.hub_model_id) diff --git a/examples/scripts/sft_vlm_gemma3.py b/examples/scripts/sft_vlm_gemma3.py new file mode 100644 index 0000000000000000000000000000000000000000..cd1caf1c3b622e94c36fe764fde9e1177996debd --- /dev/null +++ b/examples/scripts/sft_vlm_gemma3.py @@ -0,0 +1,223 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Train Gemma-3 on the HuggingFaceH4/llava-instruct-mix-vsft dataset (single-image). + +accelerate launch \ + --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ + examples/scripts/sft_vlm_gemma3.py \ + --dataset_name HuggingFaceH4/llava-instruct-mix-vsft \ + --model_name_or_path google/gemma-3-4b-it \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --output_dir gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft \ + --bf16 \ + --torch_dtype bfloat16 \ + --use_peft \ + --lora_target_modules all-linear \ + --attn_implementation eager + +Train Gemma-3 on the FanqingM/MMIU-Benchmark dataset (multi-image). + +accelerate launch \ + --config_file examples/accelerate_configs/deepspeed_zero3.yaml \ + examples/scripts/sft_vlm_gemma3.py \ + --dataset_name FanqingM/MMIU-Benchmark \ + --dataset_train_split test \ + --model_name_or_path google/gemma-3-4b-it \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --output_dir gemma-3-4b-it-trl-sft-MMIU-Benchmark \ + --bf16 \ + --torch_dtype bfloat16 \ + --use_peft \ + --lora_target_modules all-linear + --attn_implementation eager +""" + +import io +import os +import zipfile + +import torch +from datasets import DatasetDict, load_dataset +from huggingface_hub import hf_hub_download, list_repo_files +from PIL import Image +from transformers import AutoModelForImageTextToText, AutoProcessor + +from trl import ( + ModelConfig, + ScriptArguments, + SFTConfig, + SFTTrainer, + TrlParser, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) + + +# For multi-image example +def process_vision_info(messages: list[dict]) -> list[Image.Image]: + image_inputs = [] + for msg in messages: + content = msg.get("content", []) + if not isinstance(content, list): + content = [content] + + for element in content: + if isinstance(element, dict) and ("image" in element or element.get("type") == "image"): + if "image" in element: + image = element["image"] + else: + image = element + if image is not None: + image = Image.open(io.BytesIO(image["bytes"])) + image_inputs.append(image.convert("RGB")) + return image_inputs + + +def format_data(samples: dict[str, any]) -> dict[str, list]: + formatted_samples = {"messages": []} + for cont in range(len(samples["question"])): + images = [] + for img_path in samples["input_image_path"][cont]: + try: + with open(img_path, "rb") as f: + img_bytes = f.read() + image = Image.open(io.BytesIO(img_bytes)).convert("RGB") + images.append({"type": "image", "image": image}) + except Exception as e: + print(f"Error processing image {img_path}: {e}") + continue + + formatted_samples["messages"].append( + [ + {"role": "system", "content": [{"type": "text", "text": samples["context"][cont]}]}, + {"role": "user", "content": images + [{"type": "text", "text": samples["question"][cont]}]}, + {"role": "assistant", "content": [{"type": "text", "text": samples["output"][cont]}]}, + ] + ) + return formatted_samples + + +# For multi-image example +def prepare_dataset(dataset: DatasetDict, dataset_name: str, dataset_train_split: str) -> DatasetDict: + all_files = list_repo_files(dataset_name, repo_type="dataset") + zip_files = [f for f in all_files if f.endswith(".zip")] + + for zip_filename in zip_files: + zip_path = hf_hub_download(repo_id=dataset_name, filename=zip_filename, repo_type="dataset") + extract_folder = zip_filename.replace(".zip", "") + os.makedirs(extract_folder, exist_ok=True) + + with zipfile.ZipFile(zip_path, "r") as zip_ref: + zip_ref.extractall(extract_folder) + + dataset = dataset.map(format_data, batched=True, batch_size=4, num_proc=16) + return dataset + + +def main(): + parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_and_config() + training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) + training_args.remove_unused_columns = False + training_args.dataset_kwargs = {"skip_prepare_dataset": True} + + ################ + # Model, Tokenizer & Processor + ################ + torch_dtype = ( + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) + ) + quantization_config = get_quantization_config(model_args) + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + torch_dtype=torch_dtype, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + processor = AutoProcessor.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + processor.tokenizer.padding_side = "right" + + model = AutoModelForImageTextToText.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs + ) + + def collate_fn(examples): + texts = [ + processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=False).strip() + for example in examples + ] + if "images" in examples[0]: # single-image + images = [[img.convert("RGB") for img in example["images"]] for example in examples] + else: # multi-image + images = [process_vision_info(example["messages"]) for example in examples] + + # Tokenize the texts and process the images + batch = processor( + text=texts, images=images, return_tensors="pt", padding=True + ) # Encode texts and images into tensors + + # The labels are the input_ids, and we mask the padding tokens in the loss computation + labels = batch["input_ids"].clone() # Clone input IDs for labels + # Mask image tokens + image_token_id = [ + processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.special_tokens_map["boi_token"]) + ] + # Mask tokens for not being used in the loss computation + labels[labels == processor.tokenizer.pad_token_id] = -100 + labels[labels == image_token_id] = -100 + labels[labels == 262144] = -100 + + batch["labels"] = labels + return batch # Return the prepared batch + + ################ + # Dataset + ################ + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + if script_args.dataset_name == "FanqingM/MMIU-Benchmark": + dataset = prepare_dataset(dataset, script_args.dataset_name, script_args.dataset_train_split) + + ################ + # Training + ################ + trainer = SFTTrainer( + model=model, + args=training_args, + data_collator=collate_fn, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=processor.tokenizer, + peft_config=get_peft_config(model_args), + ) + + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + if trainer.accelerator.is_main_process: + processor.push_to_hub(training_args.hub_model_id) + + +if __name__ == "__main__": + main() diff --git a/examples/scripts/sft_vlm_smol_vlm.py b/examples/scripts/sft_vlm_smol_vlm.py new file mode 100644 index 0000000000000000000000000000000000000000..6ad5884d5f27a06ad49ba9753a3d05a31a06398a --- /dev/null +++ b/examples/scripts/sft_vlm_smol_vlm.py @@ -0,0 +1,144 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +pip install pillow + +# Tested on 8x H100 GPUs +accelerate launch + --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \ + sft_vlm_smol_vlm.py \ + --dataset_name HuggingFaceH4/llava-instruct-mix-vsft \ + --model_name_or_path HuggingFaceTB/SmolVLM-Instruct \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --output_dir sft-smol-vlm-hf \ + --bf16 \ + --torch_dtype bfloat16 \ + --gradient_checkpointing \ + --use_peft \ + --lora_target_modules down_proj, o_proj, k_proj, q_proj, gate_proj, up_proj, v_proj + +For LLaVA-NeXT, use: (requires transformers>=4.45) + --model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf + +For meta-llama/Llama-3.2-11B-Vision-Instruct, use: (requires transformers>=4.45.1) + --model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct +""" + +import torch +from datasets import load_dataset +from transformers import ( + AutoModelForVision2Seq, + AutoProcessor, + Idefics3ForConditionalGeneration, + LlavaForConditionalGeneration, +) + +from trl import ( + ModelConfig, + ScriptArguments, + SFTConfig, + SFTTrainer, + TrlParser, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) + + +if __name__ == "__main__": + parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_and_config() + training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) + training_args.remove_unused_columns = False + training_args.dataset_kwargs = {"skip_prepare_dataset": True} + + ################ + # Model, Tokenizer & Processor + ################ + torch_dtype = ( + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) + ) + quantization_config = get_quantization_config(model_args) + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + torch_dtype=torch_dtype, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + processor = AutoProcessor.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + + model = AutoModelForVision2Seq.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs + ) + + ################ + # Create a data collator to encode text and image pairs + ################ + def collate_fn(examples): + # Get the texts and images, and apply the chat template + texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples] + images = [example["images"] for example in examples] + if isinstance(model, LlavaForConditionalGeneration): + # LLava1.5 does not support multiple images + images = [image[0] for image in images] + + # Tokenize the texts and process the images + batch = processor(text=texts, images=images, return_tensors="pt", padding=True) + + # The labels are the input_ids, and we mask the padding tokens in the loss computation + labels = batch["input_ids"].clone() + labels[labels == processor.tokenizer.pad_token_id] = -100 # + # Ignore the image token index in the loss computation (model specific) + if isinstance(model, Idefics3ForConditionalGeneration): + image_token_id = processor.tokenizer.additional_special_tokens_ids[ + processor.tokenizer.additional_special_tokens.index("") + ] + else: + image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token) + labels[labels == image_token_id] = -100 + batch["labels"] = labels + + return batch + + ################ + # Dataset + ################ + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + + ################ + # Training + ################ + trainer = SFTTrainer( + model=model, + args=training_args, + data_collator=collate_fn, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=processor.tokenizer, + peft_config=get_peft_config(model_args), + ) + + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + if trainer.accelerator.is_main_process: + processor.push_to_hub(training_args.hub_model_id) diff --git a/examples/scripts/xpo.py b/examples/scripts/xpo.py new file mode 100644 index 0000000000000000000000000000000000000000..0dbd13e99202453e33a7d9041348b2a8d7401769 --- /dev/null +++ b/examples/scripts/xpo.py @@ -0,0 +1,130 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: + +python examples/scripts/xpo.py \ + --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \ + --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \ + --dataset_name trl-lib/tldr \ + --learning_rate 5.0e-7 \ + --output_dir pythia-1b-tldr-xpo \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 32 \ + --num_train_epochs 3 \ + --max_new_tokens 64 \ + --warmup_ratio 0.1 \ + --missing_eos_penalty 1.0 \ + --push_to_hub +""" + +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig + +from trl import ( + HfPairwiseJudge, + LogCompletionsCallback, + ModelConfig, + OpenAIPairwiseJudge, + PairRMJudge, + ScriptArguments, + TrlParser, + XPOConfig, + XPOTrainer, + get_kbit_device_map, + get_quantization_config, +) +from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE + + +JUDGES = {"pair_rm": PairRMJudge, "openai": OpenAIPairwiseJudge, "hf": HfPairwiseJudge} + + +if __name__ == "__main__": + parser = TrlParser((ScriptArguments, XPOConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_and_config() + training_args.gradient_checkpointing_kwargs = {"use_reentrant": True} + + torch_dtype = ( + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) + ) + quantization_config = get_quantization_config(model_args) + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + torch_dtype=torch_dtype, + use_cache=False if training_args.gradient_checkpointing else True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs + ) + ref_model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs + ) + + if training_args.reward_model_path is not None: + reward_model = AutoModelForSequenceClassification.from_pretrained( + training_args.reward_model_path, + num_labels=1, + trust_remote_code=model_args.trust_remote_code, + **model_kwargs, + ) + else: + reward_model = None + + if training_args.judge is not None: + judge_cls = JUDGES[training_args.judge] + judge = judge_cls() + else: + judge = None + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + if tokenizer.chat_template is None: + tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE + + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + + trainer = XPOTrainer( + model=model, + ref_model=ref_model, + reward_model=reward_model, + judge=judge, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=tokenizer, + ) + + if training_args.eval_strategy != "no": + generation_config = GenerationConfig( + max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature + ) + completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8) + trainer.add_callback(completions_callback) + + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..1eb32c45da87d43802b0160ab06f3291ba53cf66 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,42 @@ +[project] +name = "trl" +version = "0.0.0" +requires-python = ">=3.13" + +dependencies = [ + "accelerate>=1.7.0", + "datasets>=3.6.0", + "deepspeed>=0.17.1", + "peft>=0.15.2", + "transformers>=4.52.4", +] + +[tool.ruff] +target-version = "py39" +line-length = 119 + +[tool.ruff.lint] +ignore = [ + "B028", # warning without explicit stacklevel + "C408", # dict() calls (stylistic) + "C901", # function complexity + "E501", +] +extend-select = ["E", "F", "I", "W", "UP", "B", "T", "C"] + +[tool.ruff.lint.per-file-ignores] +# Allow prints in auxiliary scripts +"examples/**.py" = ["T201"] +"scripts/**.py" = ["T201"] +# Ignore import violations in all `__init__.py` files. +"__init__.py" = ["F401"] + +[tool.ruff.lint.isort] +lines-after-imports = 2 +known-first-party = ["trl"] + +[tool.pytest.ini_options] +markers = [ + "slow: marks tests as slow (deselect with '-m \"not slow\"')", + "low-priority: marks tests as low priority (deselect with '-m \"not low-priority\"')", +] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..0717481d0ea7c47fa221a273f9345d16e1bc979d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +accelerate>=1.4.0 +datasets>=3.0.0 +transformers>=4.51.0 \ No newline at end of file diff --git a/scripts/add_copyrights.py b/scripts/add_copyrights.py new file mode 100644 index 0000000000000000000000000000000000000000..29466f503a2dadf7f5bf25e2fa2e0777db2bf359 --- /dev/null +++ b/scripts/add_copyrights.py @@ -0,0 +1,92 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import subprocess +import sys +from datetime import datetime + + +COPYRIGHT_HEADER = f"""# Copyright 2020-{datetime.now().year} The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + + +def get_tracked_python_files(): + """Get a list of all tracked Python files using git.""" + try: + # Get the list of all tracked files from Git + result = subprocess.run(["git", "ls-files"], stdout=subprocess.PIPE, text=True, check=True) + # Split the result by lines to get individual file paths + files = result.stdout.splitlines() + # Filter only Python files + py_files = [f for f in files if f.endswith(".py")] + return py_files + except subprocess.CalledProcessError as e: + print(f"Error fetching tracked files: {e}") + return [] + + +def check_and_add_copyright(file_path): + """Check if the file contains a copyright notice, and add it if missing.""" + if not os.path.isfile(file_path): + print(f"[SKIP] {file_path} does not exist.") + return + + with open(file_path, encoding="utf-8") as f: + content = f.readlines() + + # Check if the exact copyright header exists + if "".join(content).startswith(COPYRIGHT_HEADER): + return True + + # If no copyright notice was found, prepend the header + print(f"[MODIFY] Adding copyright to {file_path}.") + with open(file_path, "w", encoding="utf-8") as f: + # Write the copyright header followed by the original content + f.write(COPYRIGHT_HEADER + "\n" + "".join(content)) + return False + + +def main(): + """Main function to check and add copyright for all tracked Python files.""" + py_files = get_tracked_python_files() + if not py_files: + print("No Python files are tracked in the repository.") + return + + print(f"Checking {len(py_files)} Python files for copyright notice...") + + have_copyright = [check_and_add_copyright(file_path) for file_path in py_files] + if not all(have_copyright): + print("❌ Some files were missing the required copyright and have been updated.") + sys.exit(1) + else: + print("✅ All files have the required copyright.") + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/scripts/generate_tiny_models.py b/scripts/generate_tiny_models.py new file mode 100644 index 0000000000000000000000000000000000000000..40d141ea9b18a95b702bd34bcf0d2532212bb3f5 --- /dev/null +++ b/scripts/generate_tiny_models.py @@ -0,0 +1,251 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script generates tiny models used in the TRL library for unit tests. It pushes them to the Hub under the +# `trl-internal-testing` organization. +# This script is meant to be run when adding new tiny model to the TRL library. + +from huggingface_hub import HfApi, ModelCard +from transformers import ( + AutoProcessor, + AutoTokenizer, + BartConfig, + BartModel, + BloomConfig, + BloomForCausalLM, + CLIPVisionConfig, + CohereConfig, + CohereForCausalLM, + DbrxConfig, + DbrxForCausalLM, + DeepseekV3Config, + DeepseekV3ForCausalLM, + FalconMambaConfig, + FalconMambaForCausalLM, + Gemma2Config, + Gemma2ForCausalLM, + GemmaConfig, + GemmaForCausalLM, + GPT2Config, + GPT2LMHeadModel, + GPTNeoXConfig, + GPTNeoXForCausalLM, + Idefics2Config, + Idefics2ForConditionalGeneration, + LlamaConfig, + LlamaForCausalLM, + LlamaForSequenceClassification, + LlavaConfig, + LlavaForConditionalGeneration, + LlavaNextConfig, + LlavaNextForConditionalGeneration, + MistralConfig, + MistralForCausalLM, + OPTConfig, + OPTForCausalLM, + PaliGemmaConfig, + PaliGemmaForConditionalGeneration, + Phi3Config, + Phi3ForCausalLM, + Qwen2Config, + Qwen2ForCausalLM, + Qwen2ForSequenceClassification, + Qwen3Config, + Qwen3ForCausalLM, + Qwen3ForSequenceClassification, + SiglipVisionConfig, + T5Config, + T5ForConditionalGeneration, +) +from transformers.models.idefics2.configuration_idefics2 import Idefics2VisionConfig + + +ORGANIZATION = "trl-internal-testing" + +MODEL_CARD = """ +--- +library_name: transformers +tags: [trl] +--- + +# Tiny {model_class_name} + +This is a minimal model built for unit tests in the [TRL](https://github.com/huggingface/trl) library. +""" + + +api = HfApi() + + +def push_to_hub(model, tokenizer, prefix=None, suffix=None): + model_class_name = model.__class__.__name__ + content = MODEL_CARD.format(model_class_name=model_class_name) + model_card = ModelCard(content) + if prefix is not None: + model_class_name = f"{prefix}-{model_class_name}" + repo_id = f"{ORGANIZATION}/{model_class_name}" + if suffix is not None: + repo_id += f"-{suffix}" + + if api.repo_exists(repo_id): + print(f"Model {repo_id} already exists, skipping") + else: + model.push_to_hub(repo_id) + tokenizer.push_to_hub(repo_id) + model_card.push_to_hub(repo_id) + + +# Decoder models +for model_id, config_class, model_class, suffix in [ + ("bigscience/bloomz-560m", BloomConfig, BloomForCausalLM, None), + ("CohereForAI/aya-expanse-8b", CohereConfig, CohereForCausalLM, None), + ("databricks/dbrx-instruct", DbrxConfig, DbrxForCausalLM, None), + ("deepseek-ai/DeepSeek-R1", DeepseekV3Config, DeepseekV3ForCausalLM, None), + # It's important to have R1-0528 as it doesn't have the same chat template + ("deepseek-ai/DeepSeek-R1-0528", DeepseekV3Config, DeepseekV3ForCausalLM, "0528"), + ("tiiuae/falcon-7b-instruct", FalconMambaConfig, FalconMambaForCausalLM, None), + ("google/gemma-2-2b-it", Gemma2Config, Gemma2ForCausalLM, None), + ("google/gemma-7b-it", GemmaConfig, GemmaForCausalLM, None), + ("openai-community/gpt2", GPT2Config, GPT2LMHeadModel, None), + ("EleutherAI/pythia-14m", GPTNeoXConfig, GPTNeoXForCausalLM, None), + ("meta-llama/Meta-Llama-3-8B-Instruct", LlamaConfig, LlamaForCausalLM, "3"), + ("meta-llama/Llama-3.1-8B-Instruct", LlamaConfig, LlamaForCausalLM, "3.1"), + ("meta-llama/Llama-3.2-1B-Instruct", LlamaConfig, LlamaForCausalLM, "3.2"), + ("mistralai/Mistral-7B-Instruct-v0.1", MistralConfig, MistralForCausalLM, "0.1"), + ("mistralai/Mistral-7B-Instruct-v0.2", MistralConfig, MistralForCausalLM, "0.2"), + ("facebook/opt-1.3b", OPTConfig, OPTForCausalLM, None), + ("microsoft/Phi-3.5-mini-instruct", Phi3Config, Phi3ForCausalLM, None), + ("Qwen/Qwen2.5-32B-Instruct", Qwen2Config, Qwen2ForCausalLM, "2.5"), + ("Qwen/Qwen2.5-Coder-0.5B", Qwen2Config, Qwen2ForCausalLM, "2.5-Coder"), + ("Qwen/Qwen3-4B", Qwen3Config, Qwen3ForCausalLM, None), +]: + tokenizer = AutoTokenizer.from_pretrained(model_id) + config = config_class( + vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()), + hidden_size=8, + num_attention_heads=4, + num_key_value_heads=2, + num_hidden_layers=2, + intermediate_size=32, + ) + model = model_class(config) + push_to_hub(model, tokenizer, "tiny", suffix) + + +# Two slightly bigger models, required for vLLM testing +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-32B-Instruct") +config = Qwen2Config( + vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()), + hidden_size=128, # increase hidden size so that hidden_size // num_attention_heads = 32, required for vLLM + num_attention_heads=4, + num_key_value_heads=2, + num_hidden_layers=2, + intermediate_size=32, +) +model = Qwen2ForCausalLM(config) +push_to_hub(model, tokenizer, "small", "2.5") + +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B") +config = Qwen3Config( + vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()), + hidden_size=128, # increase hidden size so that hidden_size // num_attention_heads = 32, required for vLLM + num_attention_heads=4, + num_key_value_heads=2, + num_hidden_layers=2, + intermediate_size=32, +) +model = Qwen3ForCausalLM(config) +push_to_hub(model, tokenizer, "small") + +# Reward models +for model_id, config_class, model_class, suffix in [ + ("meta-llama/Llama-3.2-1B-Instruct", LlamaConfig, LlamaForSequenceClassification, "3.2"), + ("Qwen/Qwen2.5-32B-Instruct", Qwen2Config, Qwen2ForSequenceClassification, "2.5"), + ("Qwen/Qwen3-4B", Qwen3Config, Qwen3ForSequenceClassification, None), +]: + tokenizer = AutoTokenizer.from_pretrained(model_id) + config = config_class( + vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()), + hidden_size=8, + num_attention_heads=4, + num_key_value_heads=2, + num_hidden_layers=2, + intermediate_size=32, + num_labels=1, + ) + model = model_class(config) + push_to_hub(model, tokenizer, "tiny", suffix) + + +# Encoder-decoder models +for model_id, config_class, model_class, suffix in [ + ("google/flan-t5-small", T5Config, T5ForConditionalGeneration, None), + ("facebook/bart-base", BartConfig, BartModel, None), +]: + tokenizer = AutoTokenizer.from_pretrained(model_id) + config = config_class( + vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()), + d_model=16, + encoder_layers=2, + decoder_layers=2, + d_kv=2, + d_ff=64, + num_layers=6, + num_heads=8, + decoder_start_token_id=0, + is_encoder_decoder=True, + ) + model = model_class(config) + push_to_hub(model, tokenizer, "tiny", suffix) + + +# Vision Language Models +# fmt: off +for model_id, config_class, text_config_class, vision_config_class, model_class in [ + ("HuggingFaceM4/idefics2-8b", Idefics2Config, MistralConfig, Idefics2VisionConfig, Idefics2ForConditionalGeneration), + ("llava-hf/llava-1.5-7b-hf", LlavaConfig, LlamaConfig, CLIPVisionConfig, LlavaForConditionalGeneration), + ("llava-hf/llava-v1.6-mistral-7b-hf", LlavaNextConfig, MistralConfig, CLIPVisionConfig, LlavaNextForConditionalGeneration), + ("google/paligemma-3b-pt-224", PaliGemmaConfig, GemmaConfig, SiglipVisionConfig, PaliGemmaForConditionalGeneration), +]: +# fmt: on + processor = AutoProcessor.from_pretrained(model_id) + kwargs = {} + if config_class == PaliGemmaConfig: + kwargs["projection_dim"] = 8 + vision_kwargs = {} + if vision_config_class in [CLIPVisionConfig, SiglipVisionConfig]: + vision_kwargs["projection_dim"] = 8 + if vision_config_class == CLIPVisionConfig: + vision_kwargs["image_size"] = 336 + vision_kwargs["patch_size"] = 14 + config = config_class( + text_config=text_config_class( + vocab_size=processor.tokenizer.vocab_size + len(processor.tokenizer.added_tokens_encoder), + hidden_size=8, + num_attention_heads=4, + num_key_value_heads=2, + num_hidden_layers=2, + intermediate_size=32, + ), + vision_config=vision_config_class( + hidden_size=8, + num_attention_heads=4, + num_hidden_layers=2, + intermediate_size=32, + **vision_kwargs, + ), + **kwargs, + ) + model = model_class(config) + push_to_hub(model, processor, "tiny") diff --git a/scripts/generate_zen_dataset.py b/scripts/generate_zen_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..730643e4cf63e5a4da065daf6989edfd9896b9ac --- /dev/null +++ b/scripts/generate_zen_dataset.py @@ -0,0 +1,660 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from datasets import Dataset +from transformers import HfArgumentParser + + +@dataclass +class ScriptArguments: + r""" + Arguments for the script. + + Args: + test_size (`float`, *optional*, defaults to `0.1`): + Fraction of the dataset to include in the test split. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether to push the dataset to the Hugging Face Hub. + repo_id (`str`, *optional*, defaults to `"trl-internal-testing/zen"`): + Hugging Face repository ID to push the dataset to. + """ + + test_size: float = field( + default=0.1, + metadata={"help": "Fraction of the dataset to include in the test split."}, + ) + push_to_hub: bool = field( + default=False, + metadata={"help": "Whether to push the dataset to the Hugging Face Hub."}, + ) + repo_id: str = field( + default="trl-internal-testing/zen", + metadata={"help": "Hugging Face repository ID to push the dataset to."}, + ) + + +def main(test_size, push_to_hub, repo_id): + # fmt: off + standard_language_modeling_dataset = Dataset.from_dict({ + "text": [ + "Beautiful is better than ugly.", + "Explicit is better than implicit.", + "Simple is better than complex.", + "Complex is better than complicated.", + "Flat is better than nested.", + "Sparse is better than dense.", + "Readability counts.", + "Special cases aren't special enough to break the rules.", + "Although practicality beats purity.", + "Errors should never pass silently.", + "Unless explicitly silenced.", + "In the face of ambiguity, refuse the temptation to guess.", + "There should be one-- and preferably only one --obvious way to do it.", + "Although that way may not be obvious at first unless you're Dutch.", + "Now is better than never.", + "Although never is often better than *right* now.", + "If the implementation is hard to explain, it's a bad idea.", + "If the implementation is easy to explain, it may be a good idea.", + "Namespaces are one honking great idea -- let's do more of those!", + ], + }) + standard_language_modeling_dataset = standard_language_modeling_dataset.train_test_split(test_size=test_size, shuffle=False) + if push_to_hub: + standard_language_modeling_dataset.push_to_hub(repo_id, config_name="standard_language_modeling") + + standard_prompt_only_dataset = Dataset.from_dict({ + "prompt": [ + "Beautiful is better than", + "Explicit is", + "Simple is better", + "Complex", + "Flat is better than", + "Sparse is better", + "Readability", + "Special cases aren't special", + "Although practicality beats", + "Errors should never", + "Unless explicitly", + "In the face of ambiguity, refuse", + "There should be one-- and preferably", + "Although that way may not be obvious at first unless you're", + "Now is", + "Although never is often", + "If the implementation is hard to explain,", + "If the implementation is easy", + "Namespaces are one honking great", + ], + }) + standard_prompt_only_dataset = standard_prompt_only_dataset.train_test_split(test_size=test_size, shuffle=False) + if push_to_hub: + standard_prompt_only_dataset.push_to_hub(repo_id, config_name="standard_prompt_only") + + standard_prompt_completion_dataset = Dataset.from_dict({ + "prompt": [ + "Beautiful is better than", + "Explicit is", + "Simple is better", + "Complex", + "Flat is better than", + "Sparse is better", + "Readability", + "Special cases aren't special", + "Although practicality beats", + "Errors should never", + "Unless explicitly", + "In the face of ambiguity, refuse", + "There should be one-- and preferably", + "Although that way may not be obvious at first unless you're", + "Now is", + "Although never is often", + "If the implementation is hard to explain,", + "If the implementation is easy", + "Namespaces are one honking great", + ], + "completion": [ + " ugly.", + " better than implicit.", + " than complex.", + " is better than complicated.", + " nested.", + " than dense.", + " counts.", + " enough to break the rules.", + " purity.", + " pass silently.", + " silenced.", + " the temptation to guess.", + " only one --obvious way to do it.", + " Dutch.", + " better than never.", + " better than *right* now.", + " it's a bad idea.", + " to explain, it may be a good idea.", + " idea -- let's do more of those!", + ], + }) + standard_prompt_completion_dataset = standard_prompt_completion_dataset.train_test_split(test_size=test_size, shuffle=False) + if push_to_hub: + standard_prompt_completion_dataset.push_to_hub(repo_id, config_name="standard_prompt_completion") + + standard_preference_dataset = Dataset.from_dict({ + "prompt": [ + "Beautiful is better than", + "Explicit is", + "Simple is better", + "Complex", + "Flat is better than", + "Sparse is better", + "Readability", + "Special cases aren't special", + "Although practicality beats", + "Errors should never", + "Unless explicitly", + "In the face of ambiguity, refuse", + "There should be one-- and preferably", + "Although that way may not be obvious at first unless you're", + "Now is", + "Although never is often", + "If the implementation is hard to explain,", + "If the implementation is easy", + "Namespaces are one honking great", + ], + "chosen": [ + " ugly.", + " better than implicit.", + " than complex.", + " is better than complicated.", + " nested.", + " than dense.", + " counts.", + " enough to break the rules.", + " purity.", + " pass silently.", + " silenced.", + " the temptation to guess.", + " only one --obvious way to do it.", + " Dutch.", + " better than never.", + " better than *right* now.", + " it's a bad idea.", + " to explain, it may be a good idea.", + " idea -- let's do more of those!", + ], + "rejected": [ + " the moon.", + " worse than nothing.", + " than a long vacation.", + " is always the answer.", + " chocolate.", + " without any context.", + " is optional.", + " enough to become unicorns.", + " reality.", + " pass their driving test.", + " forgotten.", + " the opportunity to laugh.", + " two or more confusing methods.", + " a time traveler.", + " never better.", + " not even a possibility.", + " it's clearly the best choice.", + " it's probably magic.", + " watermelon -- let's plant some!", + ], + }) + standard_preference_dataset = standard_preference_dataset.train_test_split(test_size=test_size, shuffle=False) + if push_to_hub: + standard_preference_dataset.push_to_hub(repo_id, config_name="standard_preference") + + standard_implicit_prompt_preference_dataset = Dataset.from_dict({ + "chosen": [ + "Beautiful is better than ugly.", + "Explicit is better than implicit.", + "Simple is better than complex.", + "Complex is better than complicated.", + "Flat is better than nested.", + "Sparse is better than dense.", + "Readability counts.", + "Special cases aren't special enough to break the rules.", + "Although practicality beats purity.", + "Errors should never pass silently.", + "Unless explicitly silenced.", + "In the face of ambiguity, refuse the temptation to guess.", + "There should be one-- and preferably only one --obvious way to do it.", + "Although that way may not be obvious at first unless you're Dutch.", + "Now is better than never.", + "Although never is often better than *right* now.", + "If the implementation is hard to explain, it's a bad idea.", + "If the implementation is easy to explain, it may be a good idea.", + "Namespaces are one honking great idea -- let's do more of those!", + ], + "rejected": [ + "Beautiful is better than the moon.", + "Explicit is worse than nothing.", + "Simple is better than a long vacation.", + "Complex is always the answer.", + "Flat is better than chocolate.", + "Sparse is better without any context.", + "Readability is optional.", + "Special cases aren't special enough to become unicorns.", + "Although practicality beats reality.", + "Errors should never pass their driving test.", + "Unless explicitly forgotten.", + "In the face of ambiguity, refuse the opportunity to laugh.", + "There should be one-- and preferably two or more confusing methods.", + "Although that way may not be obvious at first unless you're a time traveler.", + "Now is never better.", + "Although never is often not even a possibility.", + "If the implementation is hard to explain, it's clearly the best choice.", + "If the implementation is easy it's probably magic.", + "Namespaces are one honking great watermelon -- let's plant some!", + ], + }) + standard_implicit_prompt_preference_dataset = standard_implicit_prompt_preference_dataset.train_test_split(test_size=test_size, shuffle=False) + if push_to_hub: + standard_implicit_prompt_preference_dataset.push_to_hub(repo_id, config_name="standard_implicit_prompt_preference") + + standard_unpaired_preference_dataset = Dataset.from_dict({ + "prompt": [ + "Beautiful is better than", + "Explicit is", + "Simple is better", + "Complex", + "Flat is better than", + "Sparse is better", + "Readability", + "Special cases aren't special", + "Although practicality beats", + "Errors should never", + "Unless explicitly", + "In the face of ambiguity, refuse", + "There should be one-- and preferably", + "Although that way may not be obvious at first unless you're", + "Now is", + "Although never is often", + "If the implementation is hard to explain,", + "If the implementation is easy", + "Namespaces are one honking great", + ], + "completion": [ + " ugly.", + " worse than nothing.", + " than a long vacation.", + " is better than complicated.", + " nested.", + " without any context.", + " counts.", + " enough to become unicorns.", + " purity.", + " pass silently.", + " forgotten.", + " the temptation to guess.", + " only one --obvious way to do it.", + " a time traveler.", + " better than never.", + " not even a possibility.", + " it's a bad idea.", + " it's probably magic.", + " watermelon -- let's plant some!", + ], + "label": [True, False, False, True, True, False, True, False, True, True, False, True, True, False, True, False, True, False, False], + }) + standard_unpaired_preference_dataset = standard_unpaired_preference_dataset.train_test_split(test_size=test_size, shuffle=False) + if push_to_hub: + standard_unpaired_preference_dataset.push_to_hub(repo_id, config_name="standard_unpaired_preference") + + standard_stepwise_supervision_dataset = Dataset.from_dict({ + "prompt": [ + "Beautiful is better than", + "Explicit is better than", + "Simple is better than", + "Complex is better than", + "Flat is better than", + "Sparse is better than", + "Readability counts", + "Special cases aren't special enough", + "Although practicality beats", + "Errors should never pass", + "In the face of ambiguity, refuse", + "There should be one-- and preferably only one --", + "Although that way may not be", + "Now is better than", + "Never is often better than", + "If the implementation is hard to explain, it's", + "If the implementation is easy to explain, it", + "Namespaces are one", + "Although practicality sometimes beats purity,", + ], + "completions":[ + [", let me think...", " ugly."], + [", of course,", " implicit.", " because clarity matters."], + ["... let's keep it basic,", " complex."], + [" when needed,", " complicated."], + [" in terms of structure,", " nested."], + ["... especially for readability."], + [" especially when others read it."], + [", unless...", " they follow the rules."], + [" some theoretical elegance,", " purity."], + [" silently,", " unless explicitly silenced."], + [" the temptation to guess."], + [" way to do it,"," but sometimes it's not obvious.", " especially when there's more than one possibility."], + [" clear at first,", " it will eventually emerge."], + [" later."], + [" problematic fixes."], + [" likely because it's too complicated."], + [" might be a good design."], + [" of those great ideas,", " that solve many problems."], + [" the code should still aim for balance."], + ], + "labels": [ + [False, True], + [False, True, False], + [False, True], + [True, True], + [True, False], + [True], + [False], + [True, False], + [False, False], + [False, False], + [True], + [True, True, False], + [True, True], + [False], + [True], [False], + [False], + [True, True], + [False] + ] + }) + standard_stepwise_supervision_dataset = standard_stepwise_supervision_dataset.train_test_split(test_size=test_size, shuffle=False) + if push_to_hub: + standard_stepwise_supervision_dataset.push_to_hub(repo_id, config_name="standard_stepwise_supervision") + + conversational_language_modeling_dataset = Dataset.from_dict({ + "messages": [ + [{"role": "user", "content": "What is better than ugly?"}, {"role": "assistant", "content": "Beautiful."},], + [{"role": "user", "content": "What is better than implicit?"}, {"role": "assistant", "content": "Explicit."}], + [{"role": "user", "content": "What is better than complex?"}, {"role": "assistant", "content": "Simple."}], + [{"role": "user", "content": "What is better than complicated?"}, {"role": "assistant", "content": "Complex."}], + [{"role": "user", "content": "What is better than nested?"}, {"role": "assistant", "content": "Flat."}], + [{"role": "user", "content": "What is better than dense?"}, {"role": "assistant", "content": "Sparse."}], + [{"role": "user", "content": "What counts?"}, {"role": "assistant", "content": "Readability."}], + [{"role": "user", "content": "Are special cases enough to break the rules?"}, {"role": "assistant", "content": "No, special cases aren't special enough to break the rules."}], + [{"role": "user", "content": "What beats purity?"}, {"role": "assistant", "content": "Practicality."}], + [{"role": "user", "content": "What should never pass silently?"}, {"role": "assistant", "content": "Errors."}], + [{"role": "user", "content": "When can errors pass silently?"}, {"role": "assistant", "content": "When explicitly silenced."}], + [{"role": "user", "content": "What should you do in the face of ambiguity?"}, {"role": "assistant", "content": "Refuse the temptation to guess."}], + [{"role": "user", "content": "How many ways should there be to do it?"}, {"role": "assistant", "content": "One, and preferably only one."}], + [{"role": "user", "content": "For whom may the way not be obvious at first?"}, {"role": "assistant", "content": "Dutch."}], + [{"role": "user", "content": "What is better than never?"}, {"role": "assistant", "content": "Now is better than never."}], + [{"role": "user", "content": "Is never better than *right* now?"}, {"role": "assistant", "content": "Yes, often."}], + [{"role": "user", "content": "What does it mean if the implementation is hard to explain?"}, {"role": "assistant", "content": "It means it's a bad idea."}], + [{"role": "user", "content": "What does it mean if the implementation is easy to explain?"}, {"role": "assistant", "content": "It means it may be a good idea."}], + [{"role": "user", "content": "Any great ideas?"}, {"role": "assistant", "content": "Namespaces are one honking great idea."}], + ], + }) + conversational_language_modeling_dataset = conversational_language_modeling_dataset.train_test_split(test_size=test_size, shuffle=False) + if push_to_hub: + conversational_language_modeling_dataset.push_to_hub(repo_id, config_name="conversational_language_modeling") + + conversational_prompt_only_dataset = Dataset.from_dict({ + "prompt": [ + [{"role": "user", "content": "What is better than ugly?"}], + [{"role": "user", "content": "What is better than implicit?"}], + [{"role": "user", "content": "What is better than complex?"}], + [{"role": "user", "content": "What is better than complicated?"}], + [{"role": "user", "content": "What is better than nested?"}], + [{"role": "user", "content": "What is better than dense?"}], + [{"role": "user", "content": "What counts?"}], + [{"role": "user", "content": "Are special cases enough to break the rules?"}], + [{"role": "user", "content": "What beats purity?"}], + [{"role": "user", "content": "What should never pass silently?"}], + [{"role": "user", "content": "When can errors pass silently?"}], + [{"role": "user", "content": "What should you do in the face of ambiguity?"}], + [{"role": "user", "content": "How many ways should there be to do it?"}], + [{"role": "user", "content": "For whom may the way not be obvious at first?"}], + [{"role": "user", "content": "What is better than never?"}], + [{"role": "user", "content": "Is never better than *right* now?"}], + [{"role": "user", "content": "What does it mean if the implementation is hard to explain?"}], + [{"role": "user", "content": "What does it mean if the implementation is easy to explain?"}], + [{"role": "user", "content": "Any great ideas?"}], + ], + }) + conversational_prompt_only_dataset = conversational_prompt_only_dataset.train_test_split(test_size=test_size, shuffle=False) + if push_to_hub: + conversational_prompt_only_dataset.push_to_hub(repo_id, config_name="conversational_prompt_only") + + conversational_prompt_completion_dataset = Dataset.from_dict({ + "prompt": [ + [{"role": "user", "content": "What is better than ugly?"}], + [{"role": "user", "content": "What is better than implicit?"}], + [{"role": "user", "content": "What is better than complex?"}], + [{"role": "user", "content": "What is better than complicated?"}], + [{"role": "user", "content": "What is better than nested?"}], + [{"role": "user", "content": "What is better than dense?"}], + [{"role": "user", "content": "What counts?"}], + [{"role": "user", "content": "Are special cases enough to break the rules?"}], + [{"role": "user", "content": "What beats purity?"}], + [{"role": "user", "content": "What should never pass silently?"}], + [{"role": "user", "content": "When can errors pass silently?"}], + [{"role": "user", "content": "What should you do in the face of ambiguity?"}], + [{"role": "user", "content": "How many ways should there be to do it?"}], + [{"role": "user", "content": "For whom may the way not be obvious at first?"}], + [{"role": "user", "content": "What is better than never?"}], + [{"role": "user", "content": "Is never better than *right* now?"}], + [{"role": "user", "content": "What does it mean if the implementation is hard to explain?"}], + [{"role": "user", "content": "What does it mean if the implementation is easy to explain?"}], + [{"role": "user", "content": "Any great ideas?"}], + ], + "completion": [ + [{"role": "assistant", "content": "Beautiful."}], + [{"role": "assistant", "content": "Explicit."}], + [{"role": "assistant", "content": "Simple."}], + [{"role": "assistant", "content": "Complex."}], + [{"role": "assistant", "content": "Flat."}], + [{"role": "assistant", "content": "Sparse."}], + [{"role": "assistant", "content": "Readability."}], + [{"role": "assistant", "content": "No, special cases aren't special enough to break the rules."}], + [{"role": "assistant", "content": "Practicality."}], + [{"role": "assistant", "content": "Errors."}], + [{"role": "assistant", "content": "When explicitly silenced."}], + [{"role": "assistant", "content": "Refuse the temptation to guess."}], + [{"role": "assistant", "content": "One, and preferably only one."}], + [{"role": "assistant", "content": "Dutch."}], + [{"role": "assistant", "content": "Now is better than never."}], + [{"role": "assistant", "content": "Yes, often."}], + [{"role": "assistant", "content": "It means it's a bad idea."}], + [{"role": "assistant", "content": "It means it may be a good idea."}], + [{"role": "assistant", "content": "Namespaces are one honking great idea."}], + ], + }) + conversational_prompt_completion_dataset = conversational_prompt_completion_dataset.train_test_split(test_size=test_size, shuffle=False) + if push_to_hub: + conversational_prompt_completion_dataset.push_to_hub(repo_id, config_name="conversational_prompt_completion") + + conversational_preference_dataset = Dataset.from_dict({ + "prompt": [ + [{"role": "user", "content": "What is better than ugly?"}], + [{"role": "user", "content": "What is better than implicit?"}], + [{"role": "user", "content": "What is better than complex?"}], + [{"role": "user", "content": "What is better than complicated?"}], + [{"role": "user", "content": "What is better than nested?"}], + [{"role": "user", "content": "What is better than dense?"}], + [{"role": "user", "content": "What counts?"}], + [{"role": "user", "content": "Are special cases enough to break the rules?"}], + [{"role": "user", "content": "What beats purity?"}], + [{"role": "user", "content": "What should never pass silently?"}], + [{"role": "user", "content": "When can errors pass silently?"}], + [{"role": "user", "content": "What should you do in the face of ambiguity?"}], + [{"role": "user", "content": "How many ways should there be to do it?"}], + [{"role": "user", "content": "For whom may the way not be obvious at first?"}], + [{"role": "user", "content": "What is better than never?"}], + [{"role": "user", "content": "Is never better than *right* now?"}], + [{"role": "user", "content": "What does it mean if the implementation is hard to explain?"}], + [{"role": "user", "content": "What does it mean if the implementation is easy to explain?"}], + [{"role": "user", "content": "Any great ideas?"}], + ], + "chosen": [ + [{"role": "assistant", "content": "Beautiful."}], + [{"role": "assistant", "content": "Explicit."}], + [{"role": "assistant", "content": "Simple."}], + [{"role": "assistant", "content": "Complex."}], + [{"role": "assistant", "content": "Flat."}], + [{"role": "assistant", "content": "Sparse."}], + [{"role": "assistant", "content": "Readability."}], + [{"role": "assistant", "content": "No, special cases aren't special enough to break the rules."}], + [{"role": "assistant", "content": "Practicality."}], + [{"role": "assistant", "content": "Errors."}], + [{"role": "assistant", "content": "When explicitly silenced."}], + [{"role": "assistant", "content": "Refuse the temptation to guess."}], + [{"role": "assistant", "content": "One, and preferably only one."}], + [{"role": "assistant", "content": "Dutch."}], + [{"role": "assistant", "content": "Now is better than never."}], + [{"role": "assistant", "content": "Yes, often."}], + [{"role": "assistant", "content": "It means it's a bad idea."}], + [{"role": "assistant", "content": "It means it may be a good idea."}], + [{"role": "assistant", "content": "Namespaces are one honking great idea."}], + ], + "rejected": [ + [{"role": "assistant", "content": "Acceptable."}], + [{"role": "assistant", "content": "Explained."}], + [{"role": "assistant", "content": "Very complex."}], + [{"role": "assistant", "content": "Very complicated."}], + [{"role": "assistant", "content": "Circular."}], + [{"role": "assistant", "content": "Heavy."}], + [{"role": "assistant", "content": "Looking complicated."}], + [{"role": "assistant", "content": "Yes, special cases are special enough to break the rules."}], + [{"role": "assistant", "content": "Nothing."}], + [{"role": "assistant", "content": "Warnings."}], + [{"role": "assistant", "content": "Never."}], + [{"role": "assistant", "content": "Give up."}], + [{"role": "assistant", "content": "As many as possible."}], + [{"role": "assistant", "content": "French."}], + [{"role": "assistant", "content": "Some day."}], + [{"role": "assistant", "content": "No, never."}], + [{"role": "assistant", "content": "It means it's a good idea."}], + [{"role": "assistant", "content": "It means it's a bad idea."}], + [{"role": "assistant", "content": "Recursion."}], + ], + }) + conversational_preference_dataset = conversational_preference_dataset.train_test_split(test_size=test_size, shuffle=False) + if push_to_hub: + conversational_preference_dataset.push_to_hub(repo_id, config_name="conversational_preference") + + conversational_implicit_prompt_preference_dataset = Dataset.from_dict({ + "chosen": [ + [{"role": "user", "content": "What is better than ugly?"}, {"role": "assistant", "content": "Beautiful."}], + [{"role": "user", "content": "What is better than implicit?"}, {"role": "assistant", "content": "Explicit."}], + [{"role": "user", "content": "What is better than complex?"}, {"role": "assistant", "content": "Simple."}], + [{"role": "user", "content": "What is better than complicated?"}, {"role": "assistant", "content": "Complex."}], + [{"role": "user", "content": "What is better than nested?"}, {"role": "assistant", "content": "Flat."}], + [{"role": "user", "content": "What is better than dense?"}, {"role": "assistant", "content": "Sparse."}], + [{"role": "user", "content": "What counts?"}, {"role": "assistant", "content": "Readability."}], + [{"role": "user", "content": "Are special cases enough to break the rules?"}, {"role": "assistant", "content": "No, special cases aren't special enough to break the rules."}], + [{"role": "user", "content": "What beats purity?"}, {"role": "assistant", "content": "Practicality."}], + [{"role": "user", "content": "What should never pass silently?"}, {"role": "assistant", "content": "Errors."}], + [{"role": "user", "content": "When can errors pass silently?"}, {"role": "assistant", "content": "When explicitly silenced."}], + [{"role": "user", "content": "What should you do in the face of ambiguity?"}, {"role": "assistant", "content": "Refuse the temptation to guess."}], + [{"role": "user", "content": "How many ways should there be to do it?"}, {"role": "assistant", "content": "One, and preferably only one."}], + [{"role": "user", "content": "For whom may the way not be obvious at first?"}, {"role": "assistant", "content": "Dutch."}], + [{"role": "user", "content": "What is better than never?"}, {"role": "assistant", "content": "Now is better than never."}], + [{"role": "user", "content": "Is never better than *right* now?"}, {"role": "assistant", "content": "Yes, often."}], + [{"role": "user", "content": "What does it mean if the implementation is hard to explain?"}, {"role": "assistant", "content": "It means it's a bad idea."}], + [{"role": "user", "content": "What does it mean if the implementation is easy to explain?"}, {"role": "assistant", "content": "It means it may be a good idea."}], + [{"role": "user", "content": "Any great ideas?"}, {"role": "assistant", "content": "Namespaces are one honking great idea."}], + ], + "rejected": [ + [{"role": "user", "content": "What is better than ugly?"}, {"role": "assistant", "content": "Acceptable."}], + [{"role": "user", "content": "What is better than implicit?"}, {"role": "assistant", "content": "Explained."}], + [{"role": "user", "content": "What is better than complex?"}, {"role": "assistant", "content": "Very complex."}], + [{"role": "user", "content": "What is better than complicated?"}, {"role": "assistant", "content": "Very complicated."}], + [{"role": "user", "content": "What is better than nested?"}, {"role": "assistant", "content": "Circular."}], + [{"role": "user", "content": "What is better than dense?"}, {"role": "assistant", "content": "Heavy."}], + [{"role": "user", "content": "What counts?"}, {"role": "assistant", "content": "Looking complicated."}], + [{"role": "user", "content": "Are special cases enough to break the rules?"}, {"role": "assistant", "content": "Yes, special cases are special enough to break the rules."}], + [{"role": "user", "content": "What beats purity?"}, {"role": "assistant", "content": "Nothing."}], + [{"role": "user", "content": "What should never pass silently?"}, {"role": "assistant", "content": "Warnings."}], + [{"role": "user", "content": "When can errors pass silently?"}, {"role": "assistant", "content": "Never."}], + [{"role": "user", "content": "What should you do in the face of ambiguity?"}, {"role": "assistant", "content": "Give up."}], + [{"role": "user", "content": "How many ways should there be to do it?"}, {"role": "assistant", "content": "As many as possible."}], + [{"role": "user", "content": "For whom may the way not be obvious at first?"}, {"role": "assistant", "content": "French."}], + [{"role": "user", "content": "What is better than never?"}, {"role": "assistant", "content": "Some day."}], + [{"role": "user", "content": "Is never better than *right* now?"}, {"role": "assistant", "content": "No, never."}], + [{"role": "user", "content": "What does it mean if the implementation is hard to explain?"}, {"role": "assistant", "content": "It means it's a good idea."}], + [{"role": "user", "content": "What does it mean if the implementation is easy to explain?"}, {"role": "assistant", "content": "It means it's a bad idea."}], + [{"role": "user", "content": "Any great ideas?"}, {"role": "assistant", "content": "Recursion."}], + ], + }) + conversational_implicit_prompt_preference_dataset = conversational_implicit_prompt_preference_dataset.train_test_split(test_size=test_size, shuffle=False) + if push_to_hub: + conversational_implicit_prompt_preference_dataset.push_to_hub(repo_id, config_name="conversational_implicit_prompt_preference") + + conversational_unpaired_preference_dataset = Dataset.from_dict({ + "prompt": [ + [{"role": "user", "content": "What is better than ugly?"}], + [{"role": "user", "content": "What is better than implicit?"}], + [{"role": "user", "content": "What is better than complex?"}], + [{"role": "user", "content": "What is better than complicated?"}], + [{"role": "user", "content": "What is better than nested?"}], + [{"role": "user", "content": "What is better than dense?"}], + [{"role": "user", "content": "What counts?"}], + [{"role": "user", "content": "Are special cases enough to break the rules?"}], + [{"role": "user", "content": "What beats purity?"}], + [{"role": "user", "content": "What should never pass silently?"}], + [{"role": "user", "content": "When can errors pass silently?"}], + [{"role": "user", "content": "What should you do in the face of ambiguity?"}], + [{"role": "user", "content": "How many ways should there be to do it?"}], + [{"role": "user", "content": "For whom may the way not be obvious at first?"}], + [{"role": "user", "content": "What is better than never?"}], + [{"role": "user", "content": "Is never better than *right* now?"}], + [{"role": "user", "content": "What does it mean if the implementation is hard to explain?"}], + [{"role": "user", "content": "What does it mean if the implementation is easy to explain?"}], + [{"role": "user", "content": "Any great ideas?"}], + ], + "completion": [ + [{'role': 'assistant', 'content': 'Beautiful.'}], + [{'role': 'assistant', 'content': 'Explicit.'}], + [{'role': 'assistant', 'content': 'Simple.'}], + [{'role': 'assistant', 'content': 'Very complicated.'}], + [{'role': 'assistant', 'content': 'Flat.'}], + [{'role': 'assistant', 'content': 'Sparse.'}], + [{'role': 'assistant', 'content': 'Readability.'}], + [{'role': 'assistant', 'content': 'Yes, special cases are special enough to break the rules.'}], + [{'role': 'assistant', 'content': 'Practicality.'}], + [{'role': 'assistant', 'content': 'Warnings.'}], + [{'role': 'assistant', 'content': 'When explicitly silenced.'}], + [{'role': 'assistant', 'content': 'Give up.'}], + [{'role': 'assistant', 'content': 'One, and preferably only one.'}], + [{'role': 'assistant', 'content': 'French.'}], + [{'role': 'assistant', 'content': 'Some day.'}], + [{'role': 'assistant', 'content': 'Yes, often.'}], + [{'role': 'assistant', 'content': "It means it's a bad idea."}], + [{'role': 'assistant', 'content': 'It means it may be a good idea.'}], + [{'role': 'assistant', 'content': 'Namespaces are one honking great idea.'}], + ], + "label": [True, True, True, False, True, True, True, False, True, False, True, False, True, False, False, True, True, True, True], + }) + conversational_unpaired_preference_dataset = conversational_unpaired_preference_dataset.train_test_split(test_size=test_size, shuffle=False) + if push_to_hub: + conversational_unpaired_preference_dataset.push_to_hub(repo_id, config_name="conversational_unpaired_preference") + # fmt: on + + +if __name__ == "__main__": + parser = HfArgumentParser(ScriptArguments) + script_args = parser.parse_args_into_dataclasses()[0] + main(script_args.test_size, script_args.push_to_hub, script_args.repo_id) diff --git a/scripts/log_example_reports.py b/scripts/log_example_reports.py new file mode 100644 index 0000000000000000000000000000000000000000..843a64c2093bb48a63cacffcc4ce5c6e34b1c73d --- /dev/null +++ b/scripts/log_example_reports.py @@ -0,0 +1,158 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import os +from datetime import date + +from tabulate import tabulate + + +MAX_LEN_MESSAGE = 2900 # slack endpoint has a limit of 3001 characters + +parser = argparse.ArgumentParser() +parser.add_argument("--slack_channel_name", default="trl-push-examples-ci") +parser.add_argument("--text_file_name", required=True) + + +def main(text_file_name, slack_channel_name=None): + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + message = "" + + if os.path.isfile(text_file_name): + final_results = {} + + try: + with open(text_file_name) as file: + for line in file: + result, config_name = line.strip().split(",") + config_name = config_name.split("/")[-1].split(".yaml")[0] + final_results[config_name] = int(result) + except Exception as e: + logger.error(f"Error reading file {text_file_name}: {str(e)}") + final_results = {} + + no_error_payload = { + "type": "section", + "text": { + "type": "plain_text", + "text": "🌞 There were no failures on the example tests!" + if not len(final_results) == 0 + else "Something went wrong there is at least one empty file - please check GH action results.", + "emoji": True, + }, + } + + total_num_failed = sum(final_results.values()) + else: + no_error_payload = { + "type": "section", + "text": { + "type": "plain_text", + "text": "❌ Something is wrong with the workflow please check ASAP!" + "Something went wrong there is no text file being produced. Please check ASAP.", + "emoji": True, + }, + } + + total_num_failed = 0 + + test_type_name = text_file_name.replace(".txt", "").replace("temp_results_", "").replace("_", " ").title() + + payload = [ + { + "type": "header", + "text": { + "type": "plain_text", + "text": "🤗 Results of the {} TRL {} example tests.".format( + os.environ.get("TEST_TYPE", ""), test_type_name + ), + }, + }, + ] + + if total_num_failed > 0: + message += f"{total_num_failed} failed tests for example tests!" + + for test_name, failed in final_results.items(): + failed_table = tabulate( + [[test_name, "✅" if not failed else "❌"]], + headers=["Test Name", "Status"], + showindex="always", + tablefmt="grid", + maxcolwidths=[12], + ) + message += "\n```\n" + failed_table + "\n```" + + print(f"### {message}") + else: + payload.append(no_error_payload) + + if os.environ.get("TEST_TYPE", "") != "": + try: + from slack_sdk import WebClient + except ImportError: + logger.error("slack_sdk is not installed. Please install it to use Slack integration.") + return + + if len(message) > MAX_LEN_MESSAGE: + print(f"Truncating long message from {len(message)} to {MAX_LEN_MESSAGE}") + message = message[:MAX_LEN_MESSAGE] + "..." + + if len(message) != 0: + md_report = { + "type": "section", + "text": {"type": "mrkdwn", "text": message}, + } + payload.append(md_report) + action_button = { + "type": "section", + "text": {"type": "mrkdwn", "text": "*For more details:*"}, + "accessory": { + "type": "button", + "text": {"type": "plain_text", "text": "Check Action results", "emoji": True}, + "url": f"https://github.com/huggingface/trl/actions/runs/{os.environ['GITHUB_RUN_ID']}", + }, + } + payload.append(action_button) + + date_report = { + "type": "context", + "elements": [ + { + "type": "plain_text", + "text": f"On Push - main {os.environ.get('TEST_TYPE')} test results for {date.today()}", + }, + ], + } + payload.append(date_report) + + print(payload) + + try: + client = WebClient(token=os.environ.get("SLACK_API_TOKEN")) + response = client.chat_postMessage(channel=f"#{slack_channel_name}", text=message, blocks=payload) + if response["ok"]: + logger.info("Message sent successfully to Slack.") + else: + logger.error(f"Failed to send message to Slack: {response['error']}") + except Exception as e: + logger.error(f"Error sending message to Slack: {str(e)}") + + if __name__ == "__main__": + args = parser.parse_args() + main(args.text_file_name, args.slack_channel_name) diff --git a/scripts/log_reports.py b/scripts/log_reports.py new file mode 100644 index 0000000000000000000000000000000000000000..81ec4f2365920753ef53f401b8ae80155070c793 --- /dev/null +++ b/scripts/log_reports.py @@ -0,0 +1,169 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import logging +import os +from datetime import date +from pathlib import Path + +from tabulate import tabulate + + +MAX_LEN_MESSAGE = 2900 # Slack endpoint has a limit of 3001 characters + +parser = argparse.ArgumentParser() +parser.add_argument("--slack_channel_name", default="trl-push-ci") + +# Set up logging +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + + +def process_log_file(log): + failed_tests = [] + passed_tests = [] + section_num_failed = 0 + + try: + with open(log) as f: + for line in f: + try: + data = json.loads(line) + test_name = data.get("nodeid", "") + duration = f"{data['duration']:.4f}" if "duration" in data else "N/A" + outcome = data.get("outcome", "") + + if test_name: + if outcome == "failed": + section_num_failed += 1 + failed_tests.append([test_name, duration, log.stem.split("_")[0]]) + else: + passed_tests.append([test_name, duration, log.stem.split("_")[0]]) + except json.JSONDecodeError as e: + logging.warning(f"Could not decode line in {log}: {e}") + + except FileNotFoundError as e: + logging.error(f"Log file {log} not found: {e}") + except Exception as e: + logging.error(f"Error processing log file {log}: {e}") + + return failed_tests, passed_tests, section_num_failed + + +def main(slack_channel_name): + group_info = [] + total_num_failed = 0 + total_empty_files = [] + + log_files = list(Path().glob("*.log")) + if not log_files: + logging.info("No log files found.") + return + + for log in log_files: + failed, passed, section_num_failed = process_log_file(log) + empty_file = not failed and not passed + + total_num_failed += section_num_failed + total_empty_files.append(empty_file) + group_info.append([str(log), section_num_failed, failed]) + + # Clean up log file + try: + os.remove(log) + except OSError as e: + logging.warning(f"Could not remove log file {log}: {e}") + + # Prepare Slack message payload + payload = [ + { + "type": "header", + "text": {"type": "plain_text", "text": f"🤗 Results of the {os.environ.get('TEST_TYPE', '')} TRL tests."}, + }, + ] + + if total_num_failed > 0: + message = "" + for name, num_failed, failed_tests in group_info: + if num_failed > 0: + message += f"*{name}: {num_failed} failed test(s)*\n" + failed_table = [ + test[0].split("::")[:2] + [test[0].split("::")[-1][:30] + ".."] for test in failed_tests + ] + message += ( + "\n```\n" + + tabulate(failed_table, headers=["Test Location", "Test Name"], tablefmt="grid") + + "\n```\n" + ) + + if any(total_empty_files): + message += f"\n*{name}: Warning! Empty file - check GitHub action job*\n" + + # Logging + logging.info(f"Total failed tests: {total_num_failed}") + print(f"### {message}") + + if len(message) > MAX_LEN_MESSAGE: + message = ( + f"❌ There are {total_num_failed} failed tests in total! Please check the action results directly." + ) + + payload.append({"type": "section", "text": {"type": "mrkdwn", "text": message}}) + payload.append( + { + "type": "section", + "text": {"type": "mrkdwn", "text": "*For more details:*"}, + "accessory": { + "type": "button", + "text": {"type": "plain_text", "text": "Check Action results"}, + "url": f"https://github.com/huggingface/trl/actions/runs/{os.environ['GITHUB_RUN_ID']}", + }, + } + ) + payload.append( + { + "type": "context", + "elements": [ + { + "type": "plain_text", + "text": f"On Push main {os.environ.get('TEST_TYPE')} results for {date.today()}", + } + ], + } + ) + + # Send to Slack + from slack_sdk import WebClient + + slack_client = WebClient(token=os.environ.get("SLACK_API_TOKEN")) + slack_client.chat_postMessage(channel=f"#{slack_channel_name}", text=message, blocks=payload) + + else: + payload.append( + { + "type": "section", + "text": { + "type": "plain_text", + "text": "✅ No failures! All tests passed successfully.", + "emoji": True, + }, + } + ) + logging.info("All tests passed. No errors detected.") + + +if __name__ == "__main__": + args = parser.parse_args() + main(args.slack_channel_name) diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000000000000000000000000000000000000..5febd682d6703842dc444246e2e71828a4e3e198 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,96 @@ +[metadata] +name = trl +version = 0.19.0.dev0 +description = Train transformer language models with reinforcement learning. +long_description = file: README.md +long_description_content_type = text/markdown +author = Leandro von Werra +author_email = leandro.vonwerra@gmail.com +url = https://github.com/huggingface/trl +keywords = transformers, huggingface, language modeling, post-training, rlhf, sft, dpo, grpo +license_file = LICENSE +classifiers = + Development Status :: 2 - Pre-Alpha + Intended Audience :: Developers + Intended Audience :: Science/Research + Natural Language :: English + Operating System :: OS Independent + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.9 + Programming Language :: Python :: 3.10 + Programming Language :: Python :: 3.11 + Programming Language :: Python :: 3.12 + Programming Language :: Python :: 3.13 + +[options] +packages = find: +python_requires = >=3.9 +include_package_data = True +install_requires = + accelerate>=1.4.0 + datasets>=3.0.0 + transformers>=4.51.0 + +[options.packages.find] +exclude = + tests* + +[options.package_data] +trl = + templates/*.md + accelerate_configs/*.yaml + +[options.extras_require] +bco = + scikit-learn + joblib +deepspeed = + deepspeed>=0.14.4 +diffusers = + diffusers>=0.18.0 +judges = + openai>=1.23.2 + llm-blender>=0.0.2 +liger = + liger-kernel>=0.5.9 +peft = + peft>=0.8.0 +quantization = + bitsandbytes +scikit = + scikit-learn +test = + parameterized + pytest-cov + pytest-rerunfailures + pytest-xdist + pytest +vllm = + # vLLM package does not yet support Python 3.13. These constraints can be lifted once support is added: + # see https://github.com/vllm-project/vllm/pull/13164 + vllm>=0.8.3; python_version < "3.13" + fastapi; python_version < "3.13" + pydantic; python_version < "3.13" + requests; python_version < "3.13" + uvicorn; python_version < "3.13" + +vlm = + Pillow +dev = + %(bco)s + %(deepspeed)s + %(diffusers)s + %(judges)s + %(liger)s + %(peft)s + %(quantization)s + %(scikit)s + %(test)s + %(vlm)s + +[options.entry_points] +console_scripts = + trl = trl.cli:main + +[coverage:run] +branch = True diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..26f52a2806c9ba33f3205b54970c37a69e259c4f --- /dev/null +++ b/setup.py @@ -0,0 +1,18 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from setuptools import setup + + +setup() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a317018578146718bb56d1b494f882305aa0535c --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/slow/__init__.py b/tests/slow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a317018578146718bb56d1b494f882305aa0535c --- /dev/null +++ b/tests/slow/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/slow/test_dpo_slow.py b/tests/slow/test_dpo_slow.py new file mode 100644 index 0000000000000000000000000000000000000000..e6ceac5c4f7d75b80296eb7ac41501576e2cad47 --- /dev/null +++ b/tests/slow/test_dpo_slow.py @@ -0,0 +1,229 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import itertools +import tempfile +import unittest + +import pytest +import torch +from accelerate.utils.memory import release_memory +from datasets import load_dataset +from parameterized import parameterized +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig +from transformers.testing_utils import backend_empty_cache, require_peft, require_torch_accelerator, torch_device +from transformers.utils import is_peft_available + +from trl import DPOConfig, DPOTrainer + +from ..testing_utils import require_bitsandbytes +from .testing_constants import DPO_LOSS_TYPES, DPO_PRECOMPUTE_LOGITS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST + + +if is_peft_available(): + from peft import LoraConfig, PeftModel + + +@pytest.mark.slow +@require_torch_accelerator +@require_peft +class DPOTrainerSlowTester(unittest.TestCase): + def setUp(self): + self.dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + self.peft_config = LoraConfig( + lora_alpha=16, + lora_dropout=0.1, + r=8, + bias="none", + task_type="CAUSAL_LM", + ) + self.max_length = 128 + + def tearDown(self): + gc.collect() + backend_empty_cache(torch_device) + gc.collect() + + @parameterized.expand(list(itertools.product(MODELS_TO_TEST, DPO_LOSS_TYPES, DPO_PRECOMPUTE_LOGITS))) + def test_dpo_bare_model(self, model_id, loss_type, pre_compute_logits): + """ + A test that tests the simple usage of `DPOTrainer` using a bare model in full precision. + """ + model = AutoModelForCausalLM.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=2, + remove_unused_columns=False, + gradient_accumulation_steps=2, + learning_rate=9e-1, + eval_strategy="steps", + fp16=True, + logging_strategy="no", + report_to="none", + beta=0.1, + loss_type=loss_type, + precompute_ref_log_probs=pre_compute_logits, + max_length=self.max_length, + ) + + # dpo train lora model + trainer = DPOTrainer( + model=model, + ref_model=None, + args=training_args, + train_dataset=self.dataset["train"], + eval_dataset=self.dataset["test"], + processing_class=tokenizer, + ) + + # train the model + trainer.train() + + # save trained model or adapter + trainer.save_model() + + release_memory(model, trainer) + + @parameterized.expand( + list( + itertools.product( + MODELS_TO_TEST, + DPO_LOSS_TYPES, + DPO_PRECOMPUTE_LOGITS, + GRADIENT_CHECKPOINTING_KWARGS, + ) + ) + ) + @require_peft + def test_dpo_peft_model(self, model_id, loss_type, pre_compute_logits, gradient_checkpointing_kwargs): + """ + A test that tests the simple usage of `DPOTrainer` using a peft model in full precision + different scenarios of gradient checkpointing. + """ + model = AutoModelForCausalLM.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=2, + remove_unused_columns=False, + gradient_accumulation_steps=2, + learning_rate=9e-1, + eval_strategy="steps", + fp16=True, + logging_strategy="no", + report_to="none", + gradient_checkpointing=True, + gradient_checkpointing_kwargs=gradient_checkpointing_kwargs, + generate_during_eval=False, + loss_type=loss_type, + precompute_ref_log_probs=pre_compute_logits, + beta=0.1, + max_length=self.max_length, + ) + + # dpo train lora model + trainer = DPOTrainer( + model=model, + ref_model=None, + args=training_args, + train_dataset=self.dataset["train"], + eval_dataset=self.dataset["test"], + processing_class=tokenizer, + peft_config=self.peft_config, + ) + + self.assertIsInstance(trainer.model, PeftModel) + self.assertIsNone(trainer.ref_model) + + # train the model + trainer.train() + + # save trained model or adapter + trainer.save_model() + + release_memory(model, trainer) + + @parameterized.expand( + list( + itertools.product( + MODELS_TO_TEST, + DPO_LOSS_TYPES, + DPO_PRECOMPUTE_LOGITS, + GRADIENT_CHECKPOINTING_KWARGS, + ) + ) + ) + @require_bitsandbytes + @require_peft + def test_dpo_peft_model_qlora(self, model_id, loss_type, pre_compute_logits, gradient_checkpointing_kwargs): + """ + A test that tests the simple usage of `DPOTrainer` using QLoRA + different scenarios of gradient checkpointing. + """ + quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16) + + model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=2, + remove_unused_columns=False, + gradient_accumulation_steps=2, + learning_rate=9e-1, + eval_strategy="steps", + fp16=True, + logging_strategy="no", + report_to="none", + gradient_checkpointing=True, + gradient_checkpointing_kwargs=gradient_checkpointing_kwargs, + beta=0.1, + generate_during_eval=False, + loss_type=loss_type, + precompute_ref_log_probs=pre_compute_logits, + max_length=self.max_length, + ) + + # dpo train lora model + trainer = DPOTrainer( + model=model, + ref_model=None, + args=training_args, + train_dataset=self.dataset["train"], + eval_dataset=self.dataset["test"], + processing_class=tokenizer, + peft_config=self.peft_config, + ) + + self.assertIsInstance(trainer.model, PeftModel) + self.assertIsNone(trainer.ref_model) + + # train the model + trainer.train() + + # save trained model or adapter + trainer.save_model() + + release_memory(model, trainer) diff --git a/tests/slow/test_grpo_slow.py b/tests/slow/test_grpo_slow.py new file mode 100644 index 0000000000000000000000000000000000000000..f61779fb0ace009c669df8d08e1fda8f80a82f40 --- /dev/null +++ b/tests/slow/test_grpo_slow.py @@ -0,0 +1,153 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tempfile +import unittest + +import pytest +import torch +from accelerate.utils.memory import release_memory +from datasets import load_dataset +from parameterized import parameterized +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.testing_utils import ( + backend_empty_cache, + require_liger_kernel, + require_peft, + require_torch_accelerator, + torch_device, +) + +from trl import GRPOConfig, GRPOTrainer + +from .testing_constants import MODELS_TO_TEST + + +@pytest.mark.slow +@require_torch_accelerator +class GRPOTrainerSlowTester(unittest.TestCase): + def setUp(self): + self.train_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + self.eval_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="test") + self.max_length = 128 + + def tearDown(self): + gc.collect() + backend_empty_cache(torch_device) + gc.collect() + + @parameterized.expand(MODELS_TO_TEST) + @require_liger_kernel + def test_training_with_liger_grpo_loss(self, model_name): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=3, + num_generations=3, + use_liger_loss=True, + max_completion_length=self.max_length, + report_to="none", + logging_strategy="no", + ) + + model = AutoModelForCausalLM.from_pretrained(model_name) + tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token + + trainer = GRPOTrainer( + model=model, + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + processing_class=tokenizer, + ) + from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss + + assert isinstance(trainer.liger_grpo_loss, LigerFusedLinearGRPOLoss) + + previous_trainable_params = {n: param.clone() for n, param in model.named_parameters()} + + trainer.train() + + for n, param in previous_trainable_params.items(): + new_param = model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + release_memory(model, trainer) + + @parameterized.expand(MODELS_TO_TEST) + @require_liger_kernel + @require_peft + def test_training_with_liger_grpo_loss_and_peft(self, model_name): + from peft import LoraConfig, TaskType + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=3, + num_generations=3, + use_liger_loss=True, + max_completion_length=self.max_length, + report_to="none", + logging_strategy="no", + ) + + model = AutoModelForCausalLM.from_pretrained(model_name) + tokenizer = AutoTokenizer.from_pretrained(model_name) + tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token + + # Configure PEFT with LoRA + peft_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + inference_mode=False, + r=8, + lora_alpha=32, + lora_dropout=0.1, + target_modules=["q_proj", "v_proj"], + ) + + trainer = GRPOTrainer( + model=model, + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + processing_class=tokenizer, + peft_config=peft_config, + ) + from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss + + assert isinstance(trainer.liger_grpo_loss, LigerFusedLinearGRPOLoss) + + # Verify PEFT adapter is properly initialized + from peft import PeftModel + + self.assertTrue(isinstance(trainer.model, PeftModel), "Model should be wrapped with PEFT") + + # Store adapter weights before training + previous_trainable_params = { + n: param.clone() for n, param in trainer.model.named_parameters() if param.requires_grad + } + self.assertTrue(len(previous_trainable_params) > 0, "No trainable parameters found in PEFT model") + + trainer.train() + + # Verify adapter weights have changed after training + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + release_memory(model, trainer) diff --git a/tests/slow/test_sft_slow.py b/tests/slow/test_sft_slow.py new file mode 100644 index 0000000000000000000000000000000000000000..052e60241610a1bc52b78c1e687dc72f0e07dfc0 --- /dev/null +++ b/tests/slow/test_sft_slow.py @@ -0,0 +1,460 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import itertools +import tempfile +import unittest + +import pytest +import torch +from accelerate.utils.memory import release_memory +from datasets import load_dataset +from parameterized import parameterized +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig +from transformers.testing_utils import ( + backend_empty_cache, + require_liger_kernel, + require_peft, + require_torch_accelerator, + require_torch_multi_accelerator, + torch_device, +) +from transformers.utils import is_peft_available + +from trl import SFTConfig, SFTTrainer +from trl.models.utils import setup_chat_format + +from ..testing_utils import require_bitsandbytes +from .testing_constants import DEVICE_MAP_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST, PACKING_OPTIONS + + +if is_peft_available(): + from peft import LoraConfig, PeftModel + + +@pytest.mark.slow +@require_torch_accelerator +@require_peft +class SFTTrainerSlowTester(unittest.TestCase): + def setUp(self): + self.train_dataset = load_dataset("stanfordnlp/imdb", split="train[:10%]") + self.eval_dataset = load_dataset("stanfordnlp/imdb", split="test[:10%]") + self.max_length = 128 + self.peft_config = LoraConfig( + lora_alpha=16, + lora_dropout=0.1, + r=8, + bias="none", + task_type="CAUSAL_LM", + ) + + def tearDown(self): + gc.collect() + backend_empty_cache(torch_device) + gc.collect() + + @parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS))) + def test_sft_trainer_str(self, model_name, packing): + """ + Simply tests if passing a simple str to `SFTTrainer` loads and runs the trainer + as expected. + """ + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + logging_strategy="no", + report_to="none", + per_device_train_batch_size=2, + max_steps=10, + packing=packing, + max_length=self.max_length, + ) + + trainer = SFTTrainer( + model_name, + args=training_args, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + ) + + trainer.train() + + @parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS))) + def test_sft_trainer_transformers(self, model_name, packing): + """ + Simply tests if passing a transformers model to `SFTTrainer` loads and runs the trainer + as expected. + """ + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + logging_strategy="no", + report_to="none", + per_device_train_batch_size=2, + max_steps=10, + packing=packing, + max_length=self.max_length, + ) + + model = AutoModelForCausalLM.from_pretrained(model_name) + tokenizer = AutoTokenizer.from_pretrained(model_name) + + trainer = SFTTrainer( + model, + args=training_args, + processing_class=tokenizer, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + ) + + trainer.train() + + release_memory(model, trainer) + + @parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS))) + @require_peft + def test_sft_trainer_peft(self, model_name, packing): + """ + Simply tests if passing a transformers model + peft config to `SFTTrainer` loads and runs the trainer + as expected. + """ + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + logging_strategy="no", + report_to="none", + per_device_train_batch_size=2, + max_steps=10, + fp16=True, + packing=packing, + max_length=self.max_length, + ) + + model = AutoModelForCausalLM.from_pretrained(model_name) + tokenizer = AutoTokenizer.from_pretrained(model_name) + + trainer = SFTTrainer( + model, + args=training_args, + processing_class=tokenizer, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + peft_config=self.peft_config, + ) + + self.assertIsInstance(trainer.model, PeftModel) + + trainer.train() + + release_memory(model, trainer) + + @parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS))) + def test_sft_trainer_transformers_mp(self, model_name, packing): + """ + Simply tests if passing a transformers model to `SFTTrainer` loads and runs the trainer + as expected in mixed precision. + """ + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + logging_strategy="no", + report_to="none", + per_device_train_batch_size=2, + max_steps=10, + fp16=True, # this is sufficient to enable amp + packing=packing, + max_length=self.max_length, + ) + + model = AutoModelForCausalLM.from_pretrained(model_name) + tokenizer = AutoTokenizer.from_pretrained(model_name) + + trainer = SFTTrainer( + model, + args=training_args, + processing_class=tokenizer, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + ) + + trainer.train() + + release_memory(model, trainer) + + @parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS))) + def test_sft_trainer_transformers_mp_gc(self, model_name, packing, gradient_checkpointing_kwargs): + """ + Simply tests if passing a transformers model to `SFTTrainer` loads and runs the trainer + as expected in mixed precision + different scenarios of gradient_checkpointing. + """ + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + logging_strategy="no", + report_to="none", + per_device_train_batch_size=2, + max_steps=10, + packing=packing, + max_length=self.max_length, + fp16=True, # this is sufficient to enable amp + gradient_checkpointing=True, + gradient_checkpointing_kwargs=gradient_checkpointing_kwargs, + ) + + model = AutoModelForCausalLM.from_pretrained(model_name) + tokenizer = AutoTokenizer.from_pretrained(model_name) + + trainer = SFTTrainer( + model, + args=training_args, + processing_class=tokenizer, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + ) + + trainer.train() + + release_memory(model, trainer) + + @parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS))) + @require_peft + def test_sft_trainer_transformers_mp_gc_peft(self, model_name, packing, gradient_checkpointing_kwargs): + """ + Simply tests if passing a transformers model + PEFT to `SFTTrainer` loads and runs the trainer + as expected in mixed precision + different scenarios of gradient_checkpointing. + """ + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + logging_strategy="no", + report_to="none", + per_device_train_batch_size=2, + max_steps=10, + packing=packing, + max_length=self.max_length, + fp16=True, # this is sufficient to enable amp + gradient_checkpointing=True, + gradient_checkpointing_kwargs=gradient_checkpointing_kwargs, + ) + + model = AutoModelForCausalLM.from_pretrained(model_name) + tokenizer = AutoTokenizer.from_pretrained(model_name) + + trainer = SFTTrainer( + model, + args=training_args, + processing_class=tokenizer, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + peft_config=self.peft_config, + ) + + self.assertIsInstance(trainer.model, PeftModel) + + trainer.train() + + release_memory(model, trainer) + + @parameterized.expand( + list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS, DEVICE_MAP_OPTIONS)) + ) + @require_torch_multi_accelerator + def test_sft_trainer_transformers_mp_gc_device_map( + self, model_name, packing, gradient_checkpointing_kwargs, device_map + ): + """ + Simply tests if passing a transformers model to `SFTTrainer` loads and runs the trainer + as expected in mixed precision + different scenarios of gradient_checkpointing (single, multi-gpu, etc). + """ + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + logging_strategy="no", + report_to="none", + per_device_train_batch_size=2, + max_steps=10, + packing=packing, + max_length=self.max_length, + fp16=True, # this is sufficient to enable amp + gradient_checkpointing=True, + gradient_checkpointing_kwargs=gradient_checkpointing_kwargs, + ) + + model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map) + tokenizer = AutoTokenizer.from_pretrained(model_name) + + trainer = SFTTrainer( + model, + args=training_args, + processing_class=tokenizer, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + ) + + trainer.train() + + release_memory(model, trainer) + + @parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS))) + @require_peft + @require_bitsandbytes + def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, packing, gradient_checkpointing_kwargs): + """ + Simply tests if passing a transformers model + PEFT + bnb to `SFTTrainer` loads and runs the trainer + as expected in mixed precision + different scenarios of gradient_checkpointing. + """ + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + logging_strategy="no", + report_to="none", + per_device_train_batch_size=2, + max_steps=10, + packing=packing, + max_length=self.max_length, + fp16=True, # this is sufficient to enable amp + gradient_checkpointing=True, + gradient_checkpointing_kwargs=gradient_checkpointing_kwargs, + ) + + quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16) + + model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config) + tokenizer = AutoTokenizer.from_pretrained(model_name) + + trainer = SFTTrainer( + model, + args=training_args, + processing_class=tokenizer, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + peft_config=self.peft_config, + ) + + self.assertIsInstance(trainer.model, PeftModel) + + trainer.train() + + release_memory(model, trainer) + + @parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS))) + @require_peft + @require_bitsandbytes + def test_sft_trainer_with_chat_format_qlora(self, model_name, packing): + """ + Simply tests if using setup_chat_format with a transformers model + peft + bnb config to `SFTTrainer` loads and runs the trainer + as expected. + """ + with tempfile.TemporaryDirectory() as tmp_dir: + train_dataset = load_dataset("trl-internal-testing/dolly-chatml-sft", split="train") + + training_args = SFTConfig( + packing=packing, + max_length=self.max_length, + output_dir=tmp_dir, + logging_strategy="no", + report_to="none", + per_device_train_batch_size=2, + max_steps=10, + fp16=True, + ) + + quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16) + + model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config) + tokenizer = AutoTokenizer.from_pretrained(model_name) + + if tokenizer.chat_template is None: + model, tokenizer = setup_chat_format(model, tokenizer) + + trainer = SFTTrainer( + model, + args=training_args, + processing_class=tokenizer, + train_dataset=train_dataset, + peft_config=self.peft_config, + ) + + self.assertIsInstance(trainer.model, PeftModel) + + trainer.train() + + release_memory(model, trainer) + + @parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS))) + @require_liger_kernel + def test_sft_trainer_with_liger(self, model_name, packing): + """ + Tests if passing use_liger=True to SFTConfig loads and runs the trainer + with AutoLigerKernelForCausalLM as expected. + """ + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + logging_strategy="no", + report_to="none", + per_device_train_batch_size=2, + max_steps=2, + packing=packing, + max_length=self.max_length, + use_liger_kernel=True, + ) + + trainer = SFTTrainer( + model_name, + args=training_args, + train_dataset=self.train_dataset, + eval_dataset=self.eval_dataset, + ) + + trainer.train() + + release_memory(trainer.model, trainer) + + @parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS))) + @require_torch_accelerator + def test_train_offloading(self, model_name, packing): + """Test that activation offloading works with SFTTrainer.""" + + with tempfile.TemporaryDirectory() as tmp_dir: + # Initialize the trainer + training_args = SFTConfig( + output_dir=tmp_dir, + activation_offloading=True, + report_to="none", + per_device_train_batch_size=2, + max_steps=2, + packing=packing, + max_length=self.max_length, + ) + trainer = SFTTrainer( + model=model_name, args=training_args, train_dataset=self.train_dataset, eval_dataset=self.eval_dataset + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + release_memory(trainer.model, trainer) diff --git a/tests/slow/testing_constants.py b/tests/slow/testing_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..1dc30320c7fa826da9359b3dcf3c6e1c479e7241 --- /dev/null +++ b/tests/slow/testing_constants.py @@ -0,0 +1,26 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +MODELS_TO_TEST = [ + "trl-internal-testing/tiny-LlamaForCausalLM-3.2", + "trl-internal-testing/tiny-MistralForCausalLM-0.2", +] + +# We could have also not declared these variables but let's be verbose +PACKING_OPTIONS = [True, False] +GRADIENT_CHECKPOINTING_KWARGS = [None, {"use_reentrant": False}, {"use_reentrant": True}] +DEVICE_MAP_OPTIONS = [{"": 0}, "auto"] + +DPO_LOSS_TYPES = ["sigmoid", "ipo"] +DPO_PRECOMPUTE_LOGITS = [True, False] diff --git a/tests/test_activation_offloading.py b/tests/test_activation_offloading.py new file mode 100644 index 0000000000000000000000000000000000000000..9618a49452b92b49964acb8f02db98c7900b4198 --- /dev/null +++ b/tests/test_activation_offloading.py @@ -0,0 +1,156 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from torch import nn +from transformers import AutoModelForCausalLM +from transformers.testing_utils import require_peft, require_torch_accelerator, torch_device +from transformers.utils import is_peft_available + +from trl.models.activation_offloading import NoOpManager, OffloadActivations + + +if is_peft_available(): + from peft import LoraConfig, get_peft_model + + +class TestActivationOffloading(unittest.TestCase): + @require_torch_accelerator + @require_peft + def test_offloading_with_peft_models(self) -> None: + """Test that activation offloading works with PEFT models.""" + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device) + peft_config = LoraConfig( + lora_alpha=16, + lora_dropout=0.1, + r=8, + bias="none", + task_type="CAUSAL_LM", + ) + + model = get_peft_model(model, peft_config) + inp = torch.randint(0, 100, (2, 10), device=torch_device) + + # First forward-backward pass without offloading + torch.manual_seed(42) + loss = model(inp, labels=inp).loss + loss.backward() + + # Store gradients - only from trainable parameters + grads_original = [] + for name, param in model.named_parameters(): + if param.requires_grad and param.grad is not None: + grads_original.append((name, param.grad.clone())) + + # Reset gradients + for p in model.parameters(): + if p.grad is not None: + p.grad = None + + # Second forward-backward pass with offloading + torch.manual_seed(42) + with OffloadActivations(): + loss_c = model(inp, labels=inp).loss + loss_c.backward() + + # Compare gradients - only trainable parameters + for name_orig, grad_orig in grads_original: + for name_param, param in model.named_parameters(): + if name_param == name_orig and param.requires_grad and param.grad is not None: + self.assertTrue( + torch.allclose(grad_orig, param.grad, rtol=1e-4, atol=1e-5), + f"Gradient mismatch for {name_orig}", + ) + + @require_torch_accelerator + def test_noop_manager_with_offloading(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device) + inp = torch.randint(0, 100, (2, 10), device=torch_device) + + # Run with offloading but disable for specific section + with OffloadActivations(): + # First forward-backward with normal offloading + torch.manual_seed(42) + out1 = model(inp, labels=inp) + out1.loss.backward() + grads1 = [p.grad.clone() for p in model.parameters()] + + # Reset grads + for p in model.parameters(): + p.grad = None + + # Second forward-backward with NoOpManager + with NoOpManager(): + torch.manual_seed(42) + out2 = model(inp, labels=inp) + out2.loss.backward() + + grads2 = [p.grad.clone() for p in model.parameters()] + + # Gradients should match as NoOpManager should have prevented offloading + for g1, g2 in zip(grads1, grads2): + self.assertTrue(torch.allclose(g1, g2, rtol=1e-4, atol=1e-5)) + + @require_torch_accelerator + def test_min_offload_size(self): + """Test that tensors smaller than min_offload_size aren't offloaded""" + model = nn.Sequential( + nn.Linear(5, 5), # Small layer that shouldn't be offloaded + nn.Linear(5, 1000), # Large layer that should be offloaded + ).to(torch_device) + + inp = torch.randn(2, 5, device=torch_device) + + with OffloadActivations(min_offload_size=1000): + out = model(inp) + out.sum().backward() + + # The test passes if no errors occur, as we're mainly testing + # that the logic handles both offloaded and non-offloaded tensors + + @require_torch_accelerator + def test_real_hf_model(self): + """Test with an actual HuggingFace model""" + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device) + + # Create small input + inp = torch.randint(0, 100, (2, 10), device=torch_device) + + # Baseline without offloading + torch.manual_seed(42) + out1 = model(inp, labels=inp).loss + out1.backward() + grads1 = [p.grad.clone() for p in model.parameters()] + + # Reset grads + for p in model.parameters(): + p.grad = None + + # With offloading + with OffloadActivations(): + torch.manual_seed(42) + out2 = model(inp, labels=inp).loss + out2.backward() + + grads2 = [p.grad.clone() for p in model.parameters()] + + # Check outputs and gradients match + self.assertTrue(torch.allclose(out1, out2, rtol=1e-5)) + for g1, g2 in zip(grads1, grads2): + self.assertTrue(torch.allclose(g1, g2, rtol=1e-5)) diff --git a/tests/test_alignprop_trainer.py b/tests/test_alignprop_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..6b9df643bf0c20adbc1ace17b7b5bbcc9d0f579f --- /dev/null +++ b/tests/test_alignprop_trainer.py @@ -0,0 +1,93 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import pytest +import torch +from parameterized import parameterized +from transformers.utils import is_peft_available + +from trl.import_utils import is_diffusers_available + +from .testing_utils import require_diffusers + + +if is_diffusers_available() and is_peft_available(): + from trl import AlignPropConfig, AlignPropTrainer, DefaultDDPOStableDiffusionPipeline + + +def scorer_function(images, prompts, metadata): + return torch.randn(1) * 3.0, {} + + +def prompt_function(): + return ("cabbages", {}) + + +@pytest.mark.low_priority +@require_diffusers +class AlignPropTrainerTester(unittest.TestCase): + """ + Test the AlignPropTrainer class. + """ + + def setUp(self): + training_args = AlignPropConfig( + num_epochs=2, + train_gradient_accumulation_steps=1, + train_batch_size=2, + truncated_backprop_rand=False, + mixed_precision=None, + save_freq=1000000, + ) + pretrained_model = "hf-internal-testing/tiny-stable-diffusion-torch" + pretrained_revision = "main" + pipeline_with_lora = DefaultDDPOStableDiffusionPipeline( + pretrained_model, pretrained_model_revision=pretrained_revision, use_lora=True + ) + pipeline_without_lora = DefaultDDPOStableDiffusionPipeline( + pretrained_model, pretrained_model_revision=pretrained_revision, use_lora=False + ) + self.trainer_with_lora = AlignPropTrainer(training_args, scorer_function, prompt_function, pipeline_with_lora) + self.trainer_without_lora = AlignPropTrainer( + training_args, scorer_function, prompt_function, pipeline_without_lora + ) + + def tearDown(self) -> None: + gc.collect() + + @parameterized.expand([True, False]) + def test_generate_samples(self, use_lora): + trainer = self.trainer_with_lora if use_lora else self.trainer_without_lora + output_pairs = trainer._generate_samples(2, with_grad=True) + self.assertEqual(len(output_pairs.keys()), 3) + self.assertEqual(len(output_pairs["images"]), 2) + + @parameterized.expand([True, False]) + def test_calculate_loss(self, use_lora): + trainer = self.trainer_with_lora if use_lora else self.trainer_without_lora + sample = trainer._generate_samples(2) + + images = sample["images"] + prompts = sample["prompts"] + + self.assertTupleEqual(images.shape, (2, 3, 128, 128)) + self.assertEqual(len(prompts), 2) + + rewards = trainer.compute_rewards(sample) + loss = trainer.calculate_loss(rewards) + + self.assertTrue(torch.isfinite(loss.cpu())) diff --git a/tests/test_bco_trainer.py b/tests/test_bco_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..a2a625d1e7452c284b0b6cc44a4057650551decb --- /dev/null +++ b/tests/test_bco_trainer.py @@ -0,0 +1,451 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest +from functools import partial + +import torch +from accelerate import Accelerator +from datasets import load_dataset +from parameterized import parameterized +from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer +from transformers.testing_utils import require_peft +from transformers.utils import is_peft_available + +from trl import BCOConfig, BCOTrainer +from trl.trainer.bco_trainer import _process_tokens, _tokenize + +from .testing_utils import require_no_wandb, require_sklearn + + +if is_peft_available(): + from peft import LoraConfig + + +class BCOTrainerTester(unittest.TestCase): + @parameterized.expand([("standard_unpaired_preference"), ("conversational_unpaired_preference")]) + @require_sklearn + def test_train(self, config_name): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id) + ref_model = AutoModelForCausalLM.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + dataset = load_dataset("trl-internal-testing/zen", config_name, split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = BCOConfig( + output_dir=tmp_dir, + remove_unused_columns=False, # warning raised if not set to False + learning_rate=0.1, # increase the learning rate to speed up the test + report_to="none", + ) + + trainer = BCOTrainer( + model=model, + ref_model=ref_model, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.equal(param.cpu(), new_param.cpu())) + + @require_sklearn + def test_train_with_precompute(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id) + ref_model = AutoModelForCausalLM.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = BCOConfig( + output_dir=tmp_dir, + remove_unused_columns=False, # warning raised if not set to False + learning_rate=0.1, # increase the learning rate to speed up the test + precompute_ref_log_probs=True, + report_to="none", + ) + + trainer = BCOTrainer( + model=model, + ref_model=ref_model, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.equal(param.cpu(), new_param.cpu())) + + @require_sklearn + def test_train_eval(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id) + ref_model = AutoModelForCausalLM.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = BCOConfig( + output_dir=tmp_dir, + remove_unused_columns=False, # warning raised if not set to False + eval_strategy="steps", + eval_steps=3, + report_to="none", + ) + + trainer = BCOTrainer( + model=model, + ref_model=ref_model, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset["train"], + eval_dataset=dataset["test"], + ) + + trainer.train() + + @require_sklearn + def test_init_with_ref_model_is_model(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = BCOConfig( + output_dir=tmp_dir, + remove_unused_columns=False, # warning raised if not set to False + report_to="none", + ) + + with self.assertRaises(ValueError): + BCOTrainer( + model=model, + ref_model=model, # ref_model can't be the same as model + args=training_args, + processing_class=tokenizer, + train_dataset=dataset, + ) + + @require_sklearn + def test_tokenize_and_process_tokens(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id) + ref_model = AutoModelForCausalLM.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = BCOConfig( + output_dir=tmp_dir, + remove_unused_columns=False, # warning raised if not set to False + report_to="none", + ) + + trainer = BCOTrainer( + model=model, + ref_model=ref_model, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset, + ) + + tokenized_dataset = dataset.map( + _tokenize, + fn_kwargs={"tokenizer": trainer.tokenizer}, + batched=True, + batch_size=2, + ) + self.assertListEqual(tokenized_dataset["prompt"], dataset["prompt"]) + self.assertListEqual(tokenized_dataset["completion"], dataset["completion"]) + self.assertListEqual(tokenized_dataset["label"], dataset["label"]) + self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091]) + self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1, 1, 1, 1]) + self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [27261, 13]) + self.assertListEqual(tokenized_dataset["answer_attention_mask"][0], [1, 1]) + + fn_kwargs = { + "prefix": "", + "is_encoder_decoder": trainer.is_encoder_decoder, + "tokenizer": trainer.tokenizer, + "max_length": trainer.max_length, + "truncation_mode": trainer.truncation_mode, + "label_pad_token_id": trainer.label_pad_token_id, + "max_prompt_length": trainer.max_prompt_length, + } + processed_dataset = tokenized_dataset.map(_process_tokens, fn_kwargs=fn_kwargs) + self.assertListEqual(processed_dataset["prompt"], dataset["prompt"]) + self.assertListEqual(processed_dataset["completion"], dataset["completion"]) + self.assertListEqual(processed_dataset["label"], dataset["label"]) + self.assertListEqual(processed_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091]) + self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1, 1, 1, 1]) + self.assertListEqual( + processed_dataset["completion_input_ids"][0], [46518, 374, 2664, 1091, 27261, 13, 151645] + ) + self.assertListEqual(processed_dataset["completion_attention_mask"][0], [1, 1, 1, 1, 1, 1, 1]) + self.assertListEqual( + processed_dataset["completion_labels"][0], [-100, -100, -100, -100, 27261, 13, 151645] + ) + + @require_sklearn + def test_train_without_providing_ref_model(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = BCOConfig( + output_dir=tmp_dir, + remove_unused_columns=False, # warning raised if not set to False + learning_rate=0.1, # increase the learning rate to speed up the test + report_to="none", + ) + + trainer = BCOTrainer( + model=model, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.equal(param.cpu(), new_param.cpu())) + + @require_sklearn + def test_train_udm(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + # Get embedding model + embedding_model_id = "trl-internal-testing/tiny-BartModel" + embedding_model = AutoModel.from_pretrained(embedding_model_id) + embedding_tokenizer = AutoTokenizer.from_pretrained(embedding_model_id) + + def embed_prompt(input_ids, attention_mask, model): + outputs = model(input_ids=input_ids, attention_mask=attention_mask) + + return outputs.last_hidden_state.mean(dim=1) + + embedding_model = Accelerator().prepare_model(embedding_model) + embedding_func = partial(embed_prompt, model=embedding_model) + + dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = BCOConfig( + output_dir=tmp_dir, + remove_unused_columns=False, # warning raised if not set to False + learning_rate=0.1, # increase the learning rate to speed up the test + report_to="none", + ) + + trainer = BCOTrainer( + model=model, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset, + embedding_func=embedding_func, + embedding_tokenizer=embedding_tokenizer, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.equal(param.cpu(), new_param.cpu())) + + @require_sklearn + @require_peft + def test_train_without_providing_ref_model_with_lora(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id) + lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM") + tokenizer = AutoTokenizer.from_pretrained(model_id) + + dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = BCOConfig( + output_dir=tmp_dir, + remove_unused_columns=False, # warning raised if not set to False + learning_rate=0.1, # increase the learning rate to speed up the test + report_to="none", + ) + + trainer = BCOTrainer( + model=model, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset, + peft_config=lora_config, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + if "lora" in n: + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.equal(param.cpu(), new_param.cpu())) + + @require_sklearn + @require_no_wandb + def test_generate_during_eval_no_wandb(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = BCOConfig( + output_dir=tmp_dir, + remove_unused_columns=False, # warning raised if not set to False + eval_strategy="steps", + eval_steps=3, + generate_during_eval=True, + report_to="none", + ) + + with self.assertRaisesRegex( + ValueError, + expected_regex="`generate_during_eval=True` requires Weights and Biases or Comet to be installed." + " Please install `wandb` or `comet-ml` to resolve.", + ): + BCOTrainer( + model=model, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset["train"], + eval_dataset=dataset["test"], + ) + + @require_sklearn + @require_peft + def test_lora_train_and_save(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id) + lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM") + tokenizer = AutoTokenizer.from_pretrained(model_id) + + dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = BCOConfig( + output_dir=tmp_dir, + remove_unused_columns=False, # warning raised if not set to False + report_to="none", + ) + + trainer = BCOTrainer( + model=model, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset["train"], + peft_config=lora_config, + ) + + # train the model + trainer.train() + + # save peft adapter + trainer.save_model() + + # assert that the model is loaded without giving OSError + AutoModelForCausalLM.from_pretrained(tmp_dir) + + @require_sklearn + def test_compute_metrics(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id) + ref_model = AutoModelForCausalLM.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + + def dummy_compute_metrics(*args, **kwargs): + return {"test": 0.0} + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = BCOConfig( + output_dir=tmp_dir, + remove_unused_columns=False, # warning raised if not set to False + eval_strategy="steps", + eval_steps=3, + report_to="none", + ) + + trainer = BCOTrainer( + model=model, + ref_model=ref_model, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset["train"], + eval_dataset=dataset["test"], + compute_metrics=dummy_compute_metrics, + ) + + trainer.train() + + self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0) diff --git a/tests/test_best_of_n_sampler.py b/tests/test_best_of_n_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..59ff418655027f71b2891dbf7b7686db38cb592f --- /dev/null +++ b/tests/test_best_of_n_sampler.py @@ -0,0 +1,112 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from transformers import AutoTokenizer, GenerationConfig + +from trl import AutoModelForCausalLMWithValueHead +from trl.core import LengthSampler +from trl.extras import BestOfNSampler + + +def queries_to_scores(list_of_strings): + return [torch.rand(1).item() for _ in list_of_strings] + + +class BestOfNSamplerTester(unittest.TestCase): + """ + Tests the BestOfNSampler class + """ + + ref_model_name = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + output_length_sampler = LengthSampler(2, 6) + model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name) + tokenizer = AutoTokenizer.from_pretrained(ref_model_name) + tokenizer.pad_token = tokenizer.eos_token + output_length_sampler = LengthSampler(2, 6) + + def test_different_input_types(self): + r""" + Tests if the different input types normalizer works + """ + + generation_config = GenerationConfig( + min_length=-1, + top_k=0.0, + top_p=1.0, + do_sample=True, + pad_token_id=self.tokenizer.eos_token_id, + ) + + output_length_sampler = LengthSampler(2, 6) + + best_of_n = BestOfNSampler( + self.model, + self.tokenizer, + queries_to_scores, + length_sampler=output_length_sampler, + generation_config=generation_config, + ) + + queries = ["hello world", "goodbye world"] + tokenized_queries = [self.tokenizer.encode(query) for query in queries] + + various_queries_formats = [ + (tokenized_queries[0], 1), + (tokenized_queries, 2), + (torch.tensor(tokenized_queries[1]), 1), + ([torch.tensor(query) for query in tokenized_queries], 2), + ] + + for q, expected_length in various_queries_formats: + results = best_of_n.generate(q) + self.assertIsInstance(results, list) + self.assertEqual(len(results), expected_length) + + def test_different_sample_sizes_and_n_candidates_values(self): + r""" + Tests different sample sizes and n_candidates values + """ + generation_config = GenerationConfig( + min_length=-1, + top_k=0.0, + top_p=1.0, + do_sample=True, + pad_token_id=self.tokenizer.eos_token_id, + ) + + output_length_sampler = LengthSampler(6, 10) + + for sample_value, n_candidates_values, expected in [ + (4, 2, 2), + (10, 3, 3), + (6, 4, 4), + ]: + best_of_n = BestOfNSampler( + self.model, + self.tokenizer, + queries_to_scores, + length_sampler=output_length_sampler, + generation_config=generation_config, + sample_size=sample_value, + n_candidates=n_candidates_values, + ) + + queries = ["hello world", "troll the world"] + tokenized_queries = [self.tokenizer.encode(query) for query in queries] + results = best_of_n.generate(tokenized_queries) + for result in results: + self.assertEqual(len(result), expected) diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..7318793a019d0bc814d5a686b806336c5e167966 --- /dev/null +++ b/tests/test_callbacks.py @@ -0,0 +1,371 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import tempfile +import unittest + +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, Trainer, TrainingArguments +from transformers.testing_utils import require_peft, require_wandb +from transformers.trainer_utils import get_last_checkpoint +from transformers.utils import is_peft_available + +from tests.testing_utils import require_comet, require_mergekit +from trl import BasePairwiseJudge, DPOConfig, DPOTrainer, LogCompletionsCallback, MergeModelCallback, WinRateCallback +from trl.mergekit_utils import MergeConfig + + +if is_peft_available(): + from peft import LoraConfig + + +class HalfPairwiseJudge(BasePairwiseJudge): + """Naive pairwise judge that always returns [1, 0] for two prompts""" + + def judge(self, prompts, completions, shuffle_order=True, return_scores=False): + # just check that the batch size is 2 + assert len(prompts) == 2 + if return_scores: + return [0.3, 0.9] + return [1, 0] + + +class TrainerWithRefModel(Trainer): + # This is a dummy class to test the callback. Compared to the Trainer class, it only has an additional + # ref_model attribute + def __init__(self, model, ref_model, args, train_dataset, eval_dataset, processing_class): + super().__init__( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + ) + self.ref_model = ref_model + + +class WinRateCallbackTester(unittest.TestCase): + def setUp(self): + self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + self.ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + self.tokenizer.pad_token = self.tokenizer.eos_token + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + dataset["train"] = dataset["train"].select(range(8)) + self.expected_winrates = [ + {"eval_win_rate": 0.5, "epoch": 0.0, "step": 0}, + {"eval_win_rate": 0.5, "epoch": 0.5, "step": 2}, + {"eval_win_rate": 0.5, "epoch": 1.0, "step": 4}, + {"eval_win_rate": 0.5, "epoch": 1.5, "step": 6}, + {"eval_win_rate": 0.5, "epoch": 2.0, "step": 8}, + {"eval_win_rate": 0.5, "epoch": 2.5, "step": 10}, + {"eval_win_rate": 0.5, "epoch": 3.0, "step": 12}, + ] + + def tokenize_function(examples): + out = self.tokenizer(examples["prompt"], padding="max_length", max_length=16, truncation=True) + out["labels"] = out["input_ids"].copy() + return out + + self.dataset = dataset.map(tokenize_function, batched=True) + + self.generation_config = GenerationConfig(max_length=32) + self.judge = HalfPairwiseJudge() + + def test_basic(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + eval_strategy="steps", + eval_steps=2, # evaluate every 2 steps + per_device_train_batch_size=2, # 8 samples in total so 4 batches of 2 per epoch + per_device_eval_batch_size=2, + report_to="none", + ) + trainer = TrainerWithRefModel( + model=self.model, + ref_model=self.ref_model, + args=training_args, + train_dataset=self.dataset["train"], + eval_dataset=self.dataset["test"], + processing_class=self.tokenizer, + ) + win_rate_callback = WinRateCallback( + judge=self.judge, trainer=trainer, generation_config=self.generation_config + ) + trainer.add_callback(win_rate_callback) + trainer.train() + winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h] + self.assertListEqual(winrate_history, self.expected_winrates) + + def test_without_ref_model(self): + # Same as before, but without the ref_model attribute. It should use the model attribute instead + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + eval_strategy="steps", + eval_steps=2, # evaluate every 2 steps + per_device_train_batch_size=2, # 8 samples in total so 4 batches of 2 per epoch + per_device_eval_batch_size=2, + report_to="none", + ) + trainer = Trainer( + model=self.model, + args=training_args, + train_dataset=self.dataset["train"], + eval_dataset=self.dataset["test"], + processing_class=self.tokenizer, + ) + win_rate_callback = WinRateCallback( + judge=self.judge, trainer=trainer, generation_config=self.generation_config + ) + trainer.add_callback(win_rate_callback) + trainer.train() + winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h] + self.assertListEqual(winrate_history, self.expected_winrates) + + def test_soft_judge(self): + """Test that the soft judge functionality works correctly""" + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + eval_strategy="steps", + eval_steps=2, # evaluate every 2 steps + per_device_train_batch_size=2, # 8 samples in total so 4 batches of 2 per epoch + per_device_eval_batch_size=2, + report_to="none", + ) + trainer = TrainerWithRefModel( + model=self.model, + ref_model=self.ref_model, + args=training_args, + train_dataset=self.dataset["train"], + eval_dataset=self.dataset["test"], + processing_class=self.tokenizer, + ) + win_rate_callback = WinRateCallback( + judge=self.judge, trainer=trainer, generation_config=self.generation_config, use_soft_judge=True + ) + trainer.add_callback(win_rate_callback) + trainer.train() + + # Expected values based on judge returning [0.3, 0.9] for each pair + expected_soft_winrates = [ + {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 0.0, "step": 0}, + {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 0.5, "step": 2}, + {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 1.0, "step": 4}, + {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 1.5, "step": 6}, + {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 2.0, "step": 8}, + {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 2.5, "step": 10}, + {"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 3.0, "step": 12}, + ] + + winrate_history = [ + {k: h[k] for k in ["eval_avg_win_prob", "eval_win_rate", "epoch", "step"]} + for h in trainer.state.log_history + if "eval_avg_win_prob" in h + ] + self.assertListEqual(winrate_history, expected_soft_winrates) + + @require_peft + def test_lora(self): + with tempfile.TemporaryDirectory() as tmp_dir: + peft_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + self.model.add_adapter(peft_config) + training_args = TrainingArguments( + output_dir=tmp_dir, + eval_strategy="steps", + eval_steps=2, # evaluate every 2 steps + per_device_train_batch_size=2, # 8 samples in total so 4 batches of 2 per epoch + per_device_eval_batch_size=2, + report_to="none", + ) + trainer = Trainer( + model=self.model, + args=training_args, + train_dataset=self.dataset["train"], + eval_dataset=self.dataset["test"], + processing_class=self.tokenizer, + ) + win_rate_callback = WinRateCallback( + judge=self.judge, trainer=trainer, generation_config=self.generation_config + ) + trainer.add_callback(win_rate_callback) + trainer.train() + winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h] + self.assertListEqual(winrate_history, self.expected_winrates) + + +class LogCompletionsCallbackTester(unittest.TestCase): + def setUp(self): + self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + self.tokenizer.pad_token = self.tokenizer.eos_token + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + dataset["train"] = dataset["train"].select(range(8)) + + def tokenize_function(examples): + out = self.tokenizer(examples["prompt"], padding="max_length", max_length=16, truncation=True) + out["labels"] = out["input_ids"].copy() + return out + + self.dataset = dataset.map(tokenize_function, batched=True) + + self.generation_config = GenerationConfig(max_length=32) + + @require_wandb + def test_basic_wandb(self): + import wandb + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + eval_strategy="steps", + eval_steps=2, # evaluate every 2 steps + per_device_train_batch_size=2, # 8 samples in total so 4 batches of 2 per epoch + per_device_eval_batch_size=2, + report_to="wandb", + ) + trainer = Trainer( + model=self.model, + args=training_args, + train_dataset=self.dataset["train"], + eval_dataset=self.dataset["test"], + processing_class=self.tokenizer, + ) + completions_callback = LogCompletionsCallback(trainer, self.generation_config, num_prompts=2) + trainer.add_callback(completions_callback) + trainer.train() + + # Get the current run + completions_path = wandb.run.summary.completions["path"] + json_path = os.path.join(wandb.run.dir, completions_path) + with open(json_path) as f: + completions = json.load(f) + + # Check that the columns are correct + self.assertIn("step", completions["columns"]) + self.assertIn("prompt", completions["columns"]) + self.assertIn("completion", completions["columns"]) + + # Check that the prompt is in the log + self.assertIn(self.dataset["test"][0]["prompt"], completions["data"][0]) + + @require_comet + def test_basic_comet(self): + import comet_ml + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + eval_strategy="steps", + eval_steps=2, # evaluate every 2 steps + per_device_train_batch_size=2, # 8 samples in total so 4 batches of 2 per epoch + per_device_eval_batch_size=2, + report_to="comet_ml", + ) + trainer = Trainer( + model=self.model, + args=training_args, + train_dataset=self.dataset["train"], + eval_dataset=self.dataset["test"], + processing_class=self.tokenizer, + ) + completions_callback = LogCompletionsCallback(trainer, self.generation_config, num_prompts=2) + trainer.add_callback(completions_callback) + trainer.train() + + # close experiment to make sure all pending data are flushed + experiment = comet_ml.get_running_experiment() + assert experiment is not None + experiment.end() + + # get experiment assets and check that all required tables was logged + steps = len(self.dataset["train"]) + len(self.dataset["test"]) + tables_logged = int(steps / 2) + 1 # +1 to include zero step + + api_experiment = comet_ml.APIExperiment(previous_experiment=experiment.id) + tables = api_experiment.get_asset_list("dataframe") + assert tables is not None + assert len(tables) == tables_logged + assert all(table["fileName"] == "completions.csv" for table in tables) + + +@require_mergekit +class MergeModelCallbackTester(unittest.TestCase): + def setUp(self): + self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + self.dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") + + def test_callback(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + num_train_epochs=1, + report_to="none", + save_strategy="steps", + save_steps=1, + ) + config = MergeConfig() + merge_callback = MergeModelCallback(config) + trainer = DPOTrainer( + model=self.model, + args=training_args, + train_dataset=self.dataset, + processing_class=self.tokenizer, + callbacks=[merge_callback], + ) + trainer.train() + last_checkpoint = get_last_checkpoint(tmp_dir) + merged_path = os.path.join(last_checkpoint, "merged") + self.assertTrue(os.path.isdir(merged_path), "Merged folder does not exist in the last checkpoint.") + + def test_every_checkpoint(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + num_train_epochs=1, + report_to="none", + save_strategy="steps", + save_steps=1, + ) + config = MergeConfig() + merge_callback = MergeModelCallback(config, merge_at_every_checkpoint=True) + trainer = DPOTrainer( + model=self.model, + args=training_args, + train_dataset=self.dataset, + processing_class=self.tokenizer, + callbacks=[merge_callback], + ) + trainer.train() + + checkpoints = sorted( + [os.path.join(tmp_dir, cp) for cp in os.listdir(tmp_dir) if cp.startswith("checkpoint-")] + ) + + for checkpoint in checkpoints: + merged_path = os.path.join(checkpoint, "merged") + self.assertTrue( + os.path.isdir(merged_path), f"Merged folder does not exist in checkpoint {checkpoint}." + ) diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..6933273b8acf5b896f6f241bdae836a65caf65a6 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,102 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import sys +import tempfile +import unittest +from io import StringIO +from unittest.mock import patch + +import yaml + + +@unittest.skipIf( + sys.version_info < (3, 10), + "Transformers' generation codebase uses a Python >3.10 syntax (`str | None`), which seems to cause the CLI tests " + "to fail on Python <3.10.", # let's say it's a known issue, but not expected to be fixed, because too niche +) +class TestCLI(unittest.TestCase): + def test_dpo(self): + from trl.cli import main + + with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory + command = f"trl dpo --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_preference --report_to none" + with patch("sys.argv", command.split(" ")): + main() + + @patch("sys.stdout", new_callable=StringIO) + def test_env(self, mock_stdout): + from trl.cli import main + + command = "trl env" + with patch("sys.argv", command.split(" ")): + main() + self.assertIn("TRL version: ", mock_stdout.getvalue().strip()) + + def test_grpo(self): + from trl.cli import main + + with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory + command = f"trl grpo --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --reward_model_name_or_path trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_prompt_only --num_generations 4 --max_completion_length 32 --report_to none" + with patch("sys.argv", command.split(" ")): + main() + + def test_kto(self): + from trl.cli import main + + with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory + command = f"trl kto --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_unpaired_preference --report_to none" + with patch("sys.argv", command.split(" ")): + main() + + def test_sft(self): + from trl.cli import main + + with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory + command = f"trl sft --output_dir {tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_language_modeling --report_to none" + with patch("sys.argv", command.split(" ")): + main() + + def test_sft_config_file(self): + from trl.cli import main + + with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory + output_dir = os.path.join(tmp_dir, "output") + + # Create a temporary config file + config_path = os.path.join(tmp_dir, "config.yaml") + config_content = { + "model_name_or_path": "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + "dataset_name": "trl-internal-testing/zen", + "dataset_config": "standard_language_modeling", + "report_to": "none", + "output_dir": output_dir, + "lr_scheduler_type": "cosine_with_restarts", + } + with open(config_path, "w") as config_file: + yaml.dump(config_content, config_file) + + # Test the CLI with config file + command = f"trl sft --config {config_path}" + with patch("sys.argv", command.split(" ")): + main() + + # Verify that output directory was created + self.assertTrue(os.path.exists(output_dir)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_cli_utils.py b/tests/test_cli_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..24a61650046297cea59863122cc8512811c52a70 --- /dev/null +++ b/tests/test_cli_utils.py @@ -0,0 +1,264 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from dataclasses import dataclass +from unittest.mock import mock_open, patch + +from trl import TrlParser + + +@dataclass +class MyDataclass: + arg1: int + arg2: str = "default" + + +@dataclass +class InvalidDataclass: + config: str # This should raise an error in the TrlParser + + +class TestTrlParser(unittest.TestCase): + def test_init_without_config_field(self): + """Test initialization without 'config' field in the dataclasses.""" + parser = TrlParser(dataclass_types=[MyDataclass]) + self.assertIsInstance(parser, TrlParser) + + def test_init_with_config_field(self): + """Test initialization with a 'config' field in the dataclass (should raise ValueError).""" + with self.assertRaises(ValueError) as context: + TrlParser(dataclass_types=[InvalidDataclass]) + self.assertTrue("has a field named 'config'" in str(context.exception)) + + @patch("builtins.open", mock_open(read_data="env:\n VAR1: value1\n VAR2: value2\narg1: 2")) + @patch("yaml.safe_load") + @patch("os.environ", new_callable=dict) # Mock os.environ as a dictionary + def test_parse_args_and_config_with_valid_config(self, mock_environ, mock_yaml_load): + """Test parse_args_and_config method with valid arguments and config.""" + mock_yaml_load.return_value = {"env": {"VAR1": "value1", "VAR2": "value2"}, "arg1": 2} + + parser = TrlParser(dataclass_types=[MyDataclass]) + + args = ["--arg2", "value", "--config", "config.yaml"] # don't set arg1 to test default value + + # Simulate the config being loaded and environment variables being set + result_args = parser.parse_args_and_config(args) + + # Set the environment variables using the mock + mock_environ["VAR1"] = "value1" + mock_environ["VAR2"] = "value2" + + # Ensure that the environment variables were set correctly + self.assertEqual(mock_environ.get("VAR1"), "value1") + self.assertEqual(mock_environ.get("VAR2"), "value2") + + # Check the parsed arguments + self.assertEqual(len(result_args), 1) + self.assertIsInstance(result_args[0], MyDataclass) + self.assertEqual(result_args[0].arg1, 2) + self.assertEqual(result_args[0].arg2, "value") + + @patch("builtins.open", mock_open(read_data="arg1: 2")) + @patch("yaml.safe_load") + def test_parse_args_and_arg_override_config(self, mock_yaml_load): + """Test parse_args_and_config method and check that arguments override the config.""" + mock_yaml_load.return_value = {"arg1": 2} # this arg is meant to be overridden + + parser = TrlParser(dataclass_types=[MyDataclass]) + + args = ["--arg1", "3", "--config", "config.yaml"] # override arg1 default with 3 + + # Simulate the config being loaded and arguments being passed + result_args = parser.parse_args_and_config(args) + + # Check the parsed arguments + self.assertEqual(len(result_args), 1) + self.assertIsInstance(result_args[0], MyDataclass) + self.assertEqual(result_args[0].arg1, 3) + + @patch("builtins.open", mock_open(read_data="env: not_a_dict")) + @patch("yaml.safe_load") + def test_parse_args_and_config_with_invalid_env(self, mock_yaml_load): + """Test parse_args_and_config method when the 'env' field is not a dictionary.""" + mock_yaml_load.return_value = {"env": "not_a_dict"} + + parser = TrlParser(dataclass_types=[MyDataclass]) + + args = ["--arg1", "2", "--arg2", "value", "--config", "config.yaml"] + + with self.assertRaises(ValueError) as context: + parser.parse_args_and_config(args) + + self.assertEqual(str(context.exception), "`env` field should be a dict in the YAML file.") + + def test_parse_args_and_config_without_config(self): + """Test parse_args_and_config without the `--config` argument.""" + parser = TrlParser(dataclass_types=[MyDataclass]) + + args = ["--arg1", "2", "--arg2", "value"] + + # Simulate no config, just parse args normally + result_args = parser.parse_args_and_config(args) + + # Check that the arguments are parsed as is + self.assertEqual(len(result_args), 1) + self.assertIsInstance(result_args[0], MyDataclass) + self.assertEqual(result_args[0].arg1, 2) + self.assertEqual(result_args[0].arg2, "value") + + def test_set_defaults_with_config(self): + """Test set_defaults_with_config updates the defaults.""" + parser = TrlParser(dataclass_types=[MyDataclass]) + + # Update defaults + parser.set_defaults_with_config(arg1=42) + + # Ensure the default value is updated + result_args = parser.parse_args_and_config([]) + self.assertEqual(len(result_args), 1) + self.assertIsInstance(result_args[0], MyDataclass) + self.assertEqual(result_args[0].arg1, 42) + + def test_parse_args_and_config_with_remaining_strings(self): + parser = TrlParser(dataclass_types=[MyDataclass]) + + args = ["--arg1", "2", "--arg2", "value", "remaining"] + + # Simulate no config, just parse args normally + result_args = parser.parse_args_and_config(args, return_remaining_strings=True) + + # Check that the arguments are parsed as is + self.assertEqual(len(result_args), 2) + self.assertIsInstance(result_args[0], MyDataclass) + self.assertEqual(result_args[0].arg1, 2) + self.assertEqual(result_args[0].arg2, "value") + self.assertEqual(result_args[1], ["remaining"]) + + @patch("builtins.open", mock_open(read_data="remaining_string_in_config: abc")) + @patch("yaml.safe_load") + def test_parse_args_and_config_with_remaining_strings_in_config_and_args(self, mock_yaml_load): + mock_yaml_load.return_value = {"remaining_string_in_config": "abc"} + + parser = TrlParser(dataclass_types=[MyDataclass]) + + args = ["--arg1", "2", "--remaining_string_in_args", "def", "--config", "config.yaml"] + + # Simulate the config being loaded and arguments being passed + result_args = parser.parse_args_and_config(args, return_remaining_strings=True) + + # Check that the arguments are parsed as is + self.assertEqual(len(result_args), 2) + self.assertIsInstance(result_args[0], MyDataclass) + self.assertEqual(result_args[0].arg1, 2) + self.assertEqual(result_args[1], ["--remaining_string_in_config", "abc", "--remaining_string_in_args", "def"]) + + @patch("builtins.open", mock_open(read_data="arg1: 2\narg2: config_value")) + @patch("yaml.safe_load") + def test_subparsers_with_config_defaults(self, mock_yaml_load): + """Test that config defaults are applied to all subparsers.""" + mock_yaml_load.return_value = {"arg1": 2, "arg2": "config_value"} + + # Create the main parser + parser = TrlParser() + + # Add subparsers + subparsers = parser.add_subparsers(dest="command", parser_class=TrlParser) + + # Create a subparser for a specific command + subparsers.add_parser("subcommand", dataclass_types=[MyDataclass]) + + # Parse with config file + args = ["subcommand", "--config", "config.yaml"] + result_args = parser.parse_args_and_config(args) + + # Check main parser arguments + self.assertEqual(len(result_args), 1) + + # Check that config values were applied to the subparser + self.assertEqual(result_args[0].arg1, 2) # Default from config + self.assertEqual(result_args[0].arg2, "config_value") # Default from config + + @patch("builtins.open", mock_open(read_data="arg1: 2\narg2: config_value")) + @patch("yaml.safe_load") + def test_subparsers_with_config_defaults_and_arg_override(self, mock_yaml_load): + """Test that config defaults are applied to all subparsers.""" + mock_yaml_load.return_value = {"arg1": 2, "arg2": "config_value"} + + # Create the main parser + parser = TrlParser() + + # Add subparsers + subparsers = parser.add_subparsers(dest="command", parser_class=TrlParser) + + # Create a subparser for a specific command + subparsers.add_parser("subcommand", dataclass_types=[MyDataclass]) + + # Test with command line arguments overriding config + args = ["subcommand", "--arg1", "3", "--config", "config.yaml"] + result_args = parser.parse_args_and_config(args) + + # Command line arguments should override config + self.assertEqual(result_args[0].arg1, 3) + self.assertEqual(result_args[0].arg2, "config_value") # Still from config + + @patch("builtins.open", mock_open(read_data="arg1: 2\nthis_arg_does_not_exist: config_value")) + @patch("yaml.safe_load") + def test_subparsers_with_config_defaults_and_arg_override_wrong_name(self, mock_yaml_load): + """Test that config defaults are applied to all subparsers.""" + mock_yaml_load.return_value = {"arg1": 2, "this_arg_does_not_exist": "config_value"} + + # Create the main parser + parser = TrlParser() + + # Add subparsers + subparsers = parser.add_subparsers(dest="command", parser_class=TrlParser) + + # Create a subparser for a specific command + subparsers.add_parser("subcommand", dataclass_types=[MyDataclass]) + + # Test with command line arguments overriding config + args = ["subcommand", "--arg1", "3", "--config", "config.yaml"] + with self.assertRaises(ValueError): + parser.parse_args_and_config(args) + + parser.parse_args_and_config(args, fail_with_unknown_args=False) + + @patch("builtins.open", mock_open(read_data="arg1: 2\narg2: config_value")) + @patch("yaml.safe_load") + def test_subparsers_multiple_with_config_defaults(self, mock_yaml_load): + """Test that config defaults are applied to all subparsers.""" + mock_yaml_load.return_value = {"arg1": 2, "arg2": "config_value"} + + # Create the main parser + parser = TrlParser() + + # Add subparsers + subparsers = parser.add_subparsers(dest="command", parser_class=TrlParser) + + # Create a subparser for a specific command + subparsers.add_parser("subcommand0", dataclass_types=[MyDataclass]) + subparsers.add_parser("subcommand1", dataclass_types=[MyDataclass]) + + for idx in range(2): + # Parse with config file + args = [f"subcommand{idx}", "--config", "config.yaml"] + result_args = parser.parse_args_and_config(args) + + # Check main parser arguments + self.assertEqual(len(result_args), 1) + + # Check that config values were applied to the subparser + self.assertEqual(result_args[0].arg1, 2) # Default from config + self.assertEqual(result_args[0].arg2, "config_value") # Default from config diff --git a/tests/test_collators.py b/tests/test_collators.py new file mode 100644 index 0000000000000000000000000000000000000000..a55265f877702902760d670479dcbd0ecfc3cddf --- /dev/null +++ b/tests/test_collators.py @@ -0,0 +1,74 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from trl.trainer.dpo_trainer import DataCollatorForPreference + + +class TestDataCollatorForPreference(unittest.TestCase): + def setUp(self): + self.collator = DataCollatorForPreference(pad_token_id=0) + + def assertTensorEqual(self, tensor1, tensor2): + self.assertTrue(torch.equal(tensor1, tensor2), f"Tensors are not equal:\n{tensor1}\n{tensor2}") + + def test_padding_behavior(self): + examples = [ + {"prompt_input_ids": [1, 2, 3], "chosen_input_ids": [4, 5], "rejected_input_ids": [6]}, + {"prompt_input_ids": [7, 8], "chosen_input_ids": [9, 10], "rejected_input_ids": [11, 12, 13]}, + ] + output = self.collator.torch_call(examples) + + expected_prompt_input_ids = torch.tensor([[1, 2, 3], [0, 7, 8]]) + expected_prompt_attention_mask = torch.tensor([[1, 1, 1], [0, 1, 1]]) + expected_chosen_input_ids = torch.tensor([[4, 5], [9, 10]]) + expected_chosen_attention_mask = torch.tensor([[1, 1], [1, 1]]) + expected_rejected_input_ids = torch.tensor([[6, 0, 0], [11, 12, 13]]) + expected_rejected_attention_mask = torch.tensor([[1, 0, 0], [1, 1, 1]]) + + self.assertTensorEqual(output["prompt_input_ids"], expected_prompt_input_ids) + self.assertTensorEqual(output["prompt_attention_mask"], expected_prompt_attention_mask) + self.assertTensorEqual(output["chosen_input_ids"], expected_chosen_input_ids) + self.assertTensorEqual(output["chosen_attention_mask"], expected_chosen_attention_mask) + self.assertTensorEqual(output["rejected_input_ids"], expected_rejected_input_ids) + self.assertTensorEqual(output["rejected_attention_mask"], expected_rejected_attention_mask) + + def test_optional_fields(self): + examples = [ + { + "prompt_input_ids": [1], + "chosen_input_ids": [2], + "rejected_input_ids": [3], + "pixel_values": [[[0.1, 0.2], [0.3, 0.4]]], # Example 3D tensor (1x2x2) + }, + { + "prompt_input_ids": [4], + "chosen_input_ids": [5], + "rejected_input_ids": [6], + "pixel_values": [[[0.5, 0.6], [0.7, 0.8]]], # Example 3D tensor (1x2x2) + }, + ] + output = self.collator.torch_call(examples) + + expected_pixel_values = torch.tensor( + [ + [[[0.1, 0.2], [0.3, 0.4]]], + [[[0.5, 0.6], [0.7, 0.8]]], + ] + ) # Shape: (2, 1, 2, 2) + + self.assertTensorEqual(output["pixel_values"], expected_pixel_values) diff --git a/tests/test_core.py b/tests/test_core.py new file mode 100644 index 0000000000000000000000000000000000000000..959ea23964302cbba2f838c8543257acfcaf53c9 --- /dev/null +++ b/tests/test_core.py @@ -0,0 +1,46 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from trl.core import masked_mean, masked_var, masked_whiten + + +class CoreTester(unittest.TestCase): + """ + A wrapper class for testing core utils functions + """ + + def setUp(self): + self.test_input = torch.Tensor([1, 2, 3, 4]) + self.test_mask = torch.Tensor([0, 1, 1, 0]) + self.test_input_unmasked = self.test_input[1:3] + + def test_masked_mean(self): + self.assertEqual(torch.mean(self.test_input_unmasked), masked_mean(self.test_input, self.test_mask)) + + def test_masked_var(self): + self.assertEqual(torch.var(self.test_input_unmasked), masked_var(self.test_input, self.test_mask)) + + def test_masked_whiten(self): + def whiten(values: torch.Tensor) -> torch.Tensor: + mean, var = torch.mean(values), torch.var(values) + return (values - mean) * torch.rsqrt(var + 1e-8) + + whiten_unmasked = whiten(self.test_input_unmasked) + whiten_masked = masked_whiten(self.test_input, self.test_mask)[1:3] + diffs = (whiten_unmasked - whiten_masked).sum() + self.assertLess(abs(diffs.item()), 0.00001) diff --git a/tests/test_cpo_trainer.py b/tests/test_cpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..226fb5004c6c4b0ad18c211ec6d08337ed4cd5af --- /dev/null +++ b/tests/test_cpo_trainer.py @@ -0,0 +1,189 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest + +import torch +from datasets import load_dataset +from parameterized import parameterized +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer +from transformers.testing_utils import require_peft + +from trl import CPOConfig, CPOTrainer +from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE + + +class CPOTrainerTester(unittest.TestCase): + def setUp(self): + self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer.pad_token = self.tokenizer.eos_token + + # get t5 as seq2seq example: + model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration" + self.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id) + self.t5_tokenizer = AutoTokenizer.from_pretrained(model_id) + self.t5_tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE + + @parameterized.expand( + [ + ("qwen", "sigmoid", "standard_preference"), + ("t5", "hinge", "standard_implicit_prompt_preference"), + ("qwen", "ipo", "conversational_preference"), + ("t5", "ipo", "conversational_implicit_prompt_preference"), + ("qwen", "simpo", "standard_preference"), + ("t5", "simpo", "standard_implicit_prompt_preference"), + ("qwen", "hinge", "conversational_preference"), + ] + ) + def test_cpo_trainer(self, name, loss_type, config_name): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = CPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + loss_type=loss_type, + cpo_alpha=1.0, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + + if name == "qwen": + model = self.model + tokenizer = self.tokenizer + elif name == "t5": + model = self.t5_model + tokenizer = self.t5_tokenizer + training_args.is_encoder_decoder = True + + trainer = CPOTrainer( + model=model, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.equal(param, new_param)) + + @parameterized.expand( + [ + ("standard_preference",), + ("standard_implicit_prompt_preference",), + ("conversational_preference",), + ("conversational_implicit_prompt_preference",), + ] + ) + @require_peft + def test_cpo_trainer_with_lora(self, config_name): + from peft import LoraConfig + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = CPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + cpo_alpha=1.0, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + + trainer = CPOTrainer( + model=self.model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + if "lora" in n: + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.equal(param, new_param)) + + def test_compute_metrics(self): + model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + tokenizer.pad_token = tokenizer.eos_token + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + def dummy_compute_metrics(*args, **kwargs): + return {"test": 0.0} + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = CPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + remove_unused_columns=False, + do_eval=True, + eval_strategy="steps", + eval_steps=1, + per_device_eval_batch_size=2, + report_to="none", + ) + + trainer = CPOTrainer( + model=model, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + compute_metrics=dummy_compute_metrics, + ) + + trainer.train() + + self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0) diff --git a/tests/test_data_collator_completion_only.py b/tests/test_data_collator_completion_only.py new file mode 100644 index 0000000000000000000000000000000000000000..260bcfe93b2de2314515e426e6911c53ed2c2cbd --- /dev/null +++ b/tests/test_data_collator_completion_only.py @@ -0,0 +1,169 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from transformers import AutoTokenizer + +from trl import DataCollatorForCompletionOnlyLM + + +class DataCollatorForCompletionOnlyLMTester(unittest.TestCase): + def test_data_collator_finds_response_template_llama2_tokenizer(self): + # this should ideally be tested with meta-llama/Llama-2-7b-hf + self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + self.instruction = """### System: You are a helpful assistant. + +### User: How much is 2+2? + +### Assistant: 2+2 equals 4""" + self.instruction_template = "\n### User:" + self.response_template = "\n### Assistant:" + + # GPT2Tokenizer: [198, 21017, 11787, 25] -> [21017, 11787, 25] + # Llama2Tokenizer: [29871, 13, 2277, 29937, 4911, 29901] -> [2277, 29937, 4911, 29901] + # Note: If this test is ever switched to Llama2Tokenizer, this should be double checked, + # and possibly switched back to [2:] instead of [1:]. + # With GPT2Tokenizer, [1:] is correct - we want the 21017 token included, which is ###. + self.tokenized_instruction_w_context = self.tokenizer.encode( + self.instruction_template, add_special_tokens=False + )[1:] + + # GPT2Tokenizer: [198, 21017, 15286, 25] -> [15286, 25] + # Llama2Tokenizer: [29871, 13, 2277, 29937, 4007, 22137, 29901] -> [2277, 29937, 4007, 22137, 29901] + self.tokenized_response_w_context = self.tokenizer.encode(self.response_template, add_special_tokens=False)[2:] + + # Plain check on string + self.assertIn(self.response_template, self.instruction) + self.tokenized_instruction = self.tokenizer.encode(self.instruction, add_special_tokens=False) + + # Test the fix for #598 + # Pass already tokenized (w context) and truncated response_template so token_ids are like in the instruction + response + self.collator = DataCollatorForCompletionOnlyLM(self.tokenized_response_w_context, tokenizer=self.tokenizer) + self.collator.torch_call([self.tokenized_instruction]) + + # Test for PR #749 + # Pass already tokenized (w context) instruction and response both so token_ids are like in the instruction + response + self.collator = DataCollatorForCompletionOnlyLM( + self.tokenized_response_w_context, self.tokenized_instruction_w_context, tokenizer=self.tokenizer + ) + self.collator.torch_call([self.tokenized_instruction]) + + # Test for PR #1185 + # We pass in a string where the first user template is different than the rest. + # Usually this would happen due to context-sensitive tokenization, but here we + # explicitly change the template to test the fix. + self.instruction = """## User: First instruction + +### Assistant: First response + +### User: Second instruction + +### Assistant: Second response""" + self.tokenized_instruction = self.tokenizer.encode(self.instruction, add_special_tokens=False) + self.collator = DataCollatorForCompletionOnlyLM( + self.tokenized_response_w_context, self.tokenized_instruction_w_context, tokenizer=self.tokenizer + ) + collator_output = self.collator.torch_call([self.tokenized_instruction]) + collator_text = self.tokenizer.decode( + collator_output["labels"][torch.where(collator_output["labels"] != -100)] + ) + expected_text = " First response\n\n Second response" + self.assertEqual(collator_text, expected_text) + + def test_data_collator_handling_of_long_sequences(self): + self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + self.instruction = """### System: You are a helpful assistant. + +### User: How much is 2+2? I'm asking because I'm not sure. And I'm not sure because I'm not good at math. +""" + self.response_template = "\n### Assistant:" + # check DataCollatorForCompletionOnlyLM using response template only + self.tokenized_instruction = self.tokenizer.encode(self.instruction, add_special_tokens=False) + self.collator = DataCollatorForCompletionOnlyLM(self.response_template, tokenizer=self.tokenizer) + + with self.assertWarns(UserWarning): # it should raise a warning since the response_template isn't found + encoded_instance = self.collator.torch_call([self.tokenized_instruction]) + + result = torch.all(encoded_instance["labels"] == -100) + self.assertTrue(result, "Not all values in the tensor are -100.") + + # check DataCollatorForCompletionOnlyLM using response template and instruction template + self.instruction_template = "\n### User:" + self.collator = DataCollatorForCompletionOnlyLM( + self.response_template, self.instruction_template, tokenizer=self.tokenizer + ) + with self.assertWarns(UserWarning): # it should raise a warning since the response_template isn't found + encoded_instance = self.collator.torch_call([self.tokenized_instruction]) + result = torch.all(encoded_instance["labels"] == -100) + self.assertTrue(result, "Not all values in the tensor are -100.") + + def test_padding_free(self): + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + inst1 = "### System: You are a helpful assistant.\n\n### User: How much is 2+2?\n\n### Assistant: 2+2 equals 4" + inst2 = "### System: You are a honest and helpful assistant.\n\n### User: What is the answer of 22x22?\n\n### Assistant: 22x22 equals 484" + + response_template = "### Assistant:" + collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer) + collator_paddingfree = DataCollatorForCompletionOnlyLM( + response_template, tokenizer=tokenizer, padding_free=True + ) + + tokenized_instruction = [tokenizer(x, add_special_tokens=False) for x in [inst1, inst2]] + batch = collator(tokenized_instruction) + batch_paddingfree = collator_paddingfree(tokenized_instruction) + + self.assertNotIn("attention_mask", batch_paddingfree) + self.assertIn("input_ids", batch_paddingfree) + self.assertIn("labels", batch_paddingfree) + self.assertIn("position_ids", batch_paddingfree) + self.assertEqual(batch_paddingfree["input_ids"].size(), batch_paddingfree["labels"].size()) + self.assertEqual(batch_paddingfree["labels"].size(), batch_paddingfree["position_ids"].size()) + + attn_mask = batch["attention_mask"] + input_ids_remove_pad = batch["input_ids"][attn_mask.bool()].unsqueeze(0) + expected_position_ids = attn_mask.cumsum(1)[attn_mask.bool()].unsqueeze(0) - 1 + expected_labels = [] + for idx in range(batch["input_ids"].size(0)): + expected_labels.append(batch["labels"][idx][attn_mask[idx].bool()]) + expected_labels[-1][0] = collator.ignore_index + expected_labels = torch.cat(expected_labels).unsqueeze(0) + + self.assertTrue((input_ids_remove_pad == batch_paddingfree["input_ids"]).all()) + self.assertTrue((expected_position_ids == batch_paddingfree["position_ids"]).all()) + self.assertTrue((expected_labels == batch_paddingfree["labels"]).all()) + + def test_data_collator_for_completion_only_lm(self): + # The tokenizer isn't use but the collator needs it to be provided. + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + + collator = DataCollatorForCompletionOnlyLM(tokenizer.decode(9999), tokenizer=tokenizer, padding_free=True) + + tokenized_instruction = [ + {"input_ids": [1, 2, 3, 9999, 4, 5], "attention_mask": [1, 1, 1, 1, 1, 1]}, + {"input_ids": [6, 7, 8, 9, 9999, 10, 11], "attention_mask": [1, 1, 1, 1, 1, 1, 1]}, + ] + batch = collator(tokenized_instruction) + + self.assertEqual(batch["position_ids"].tolist(), [[0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6]]) # flat pos ids + self.assertEqual( + batch["cu_seq_lens_q"].tolist(), [[0, 6, 13]] + ) # start idx of each seq + total number of tokens + self.assertEqual(batch["cu_seq_lens_k"].tolist(), [[0, 6, 13]]) # idem + self.assertEqual(batch["max_length_k"], torch.tensor([7])) # max length in batch, here 7 (second sequence) + self.assertEqual(batch["max_length_q"], torch.tensor([7])) # idem diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6fc58d50f87c6ecf37b388f5f61bd6d82cea712c --- /dev/null +++ b/tests/test_data_utils.py @@ -0,0 +1,617 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import unittest + +from datasets import Dataset, DatasetDict +from parameterized import parameterized +from transformers import AutoProcessor, AutoTokenizer + +from trl.data_utils import ( + apply_chat_template, + extract_prompt, + is_conversational, + maybe_apply_chat_template, + maybe_convert_to_chatml, + maybe_extract_prompt, + maybe_unpair_preference_dataset, + pack_dataset, + pack_examples, + truncate_dataset, + unpair_preference_dataset, +) + + +class IsConversationalTester(unittest.TestCase): + conversational_examples = [ + { # Language modeling + "messages": [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}, + ], + }, + { # Prompt only + "prompt": [{"role": "user", "content": "What color is the sky?"}], + }, + { # Prompt-completion + "prompt": [{"role": "user", "content": "What color is the sky?"}], + "completion": [{"role": "assistant", "content": "It is blue."}], + }, + { # Preference + "prompt": [{"role": "user", "content": "What color is the sky?"}], + "chosen": [{"role": "assistant", "content": "It is blue."}], + "rejected": [{"role": "assistant", "content": "It is green."}], + }, + { # Preference with implicit prompt + "chosen": [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}, + ], + "rejected": [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is green."}, + ], + }, + { # Unpaired preference + "prompt": [{"role": "user", "content": "What color is the sky?"}], + "completion": [{"role": "assistant", "content": "It is blue."}], + "label": True, + }, + ] + + non_conversational_examples = [ + {"prompt": "The sky is", "completion": " blue."}, + {"text": "The sky is blue."}, + {"prompt": "The sky is"}, + {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}, + {"prompt": "The sky is", "completion": " blue.", "label": True}, + ] + + @parameterized.expand(itertools.product(conversational_examples)) + def test_conversational(self, example): + self.assertTrue(is_conversational(example)) + + @parameterized.expand(itertools.product(non_conversational_examples)) + def test_non_conversational(self, example): + self.assertFalse(is_conversational(example)) + + +class ApplyChatTemplateTester(unittest.TestCase): + tokenizers = [ + "trl-internal-testing/tiny-CohereForCausalLM", + "trl-internal-testing/tiny-DbrxForCausalLM", + "trl-internal-testing/tiny-DeepseekV3ForCausalLM", + "trl-internal-testing/tiny-DeepseekV3ForCausalLM-0528", + "trl-internal-testing/tiny-FalconMambaForCausalLM", + "trl-internal-testing/tiny-Gemma2ForCausalLM", + "trl-internal-testing/tiny-GemmaForCausalLM", + "trl-internal-testing/tiny-LlamaForCausalLM-3.1", + "trl-internal-testing/tiny-LlamaForCausalLM-3.2", + "trl-internal-testing/tiny-LlamaForCausalLM-3", + "trl-internal-testing/tiny-MistralForCausalLM-0.1", + "trl-internal-testing/tiny-MistralForCausalLM-0.2", + "trl-internal-testing/tiny-Phi3ForCausalLM", + "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + "trl-internal-testing/tiny-Qwen3ForCausalLM", + ] + + conversational_examples = [ + { # Language modeling + "messages": [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}, + ], + }, + { # Prompt only + "prompt": [{"role": "user", "content": "What color is the sky?"}], + }, + { # Prompt-completion + "prompt": [{"role": "user", "content": "What color is the sky?"}], + "completion": [{"role": "assistant", "content": "It is blue."}], + }, + { # Preference + "prompt": [{"role": "user", "content": "What color is the sky?"}], + "chosen": [{"role": "assistant", "content": "It is blue."}], + "rejected": [{"role": "assistant", "content": "It is green."}], + }, + { # Preference with implicit prompt + "chosen": [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}, + ], + "rejected": [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is green."}, + ], + }, + { # Unpaired preference + "prompt": [{"role": "user", "content": "What color is the sky?"}], + "completion": [{"role": "assistant", "content": "It is blue."}], + "label": True, + }, + ] + + non_conversational_examples = [ + {"text": "The sky is blue."}, # Language modeling + {"prompt": "The sky is"}, # Prompt only + {"prompt": "The sky is", "completion": " blue."}, # Prompt-completion + {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}, # Preference + {"chosen": "The sky is blue.", "rejected": "The sky is green."}, # Preference with implicit prompt + {"prompt": "The sky is", "completion": " blue.", "label": True}, # Unpaired preference + ] + + @parameterized.expand(itertools.product(tokenizers, conversational_examples)) + def test_apply_chat_template(self, tokenizer_id, example): + tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) + result = apply_chat_template(example, tokenizer) + + # Checking if the result is a dictionary + self.assertIsInstance(result, dict) + + # The chat template should be applied to the following keys + for key in ["prompt", "chosen", "rejected", "completion"]: + if key in example: + self.assertIn(key, result) + self.assertIsInstance(result[key], str) + + # Exception for messages, the key is "text" once the chat template is applied + if "messages" in example: + self.assertIn("text", result) + self.assertIsInstance(result["text"], str) + + # The label should be kept + if "label" in example: + self.assertIn("label", result) + self.assertIsInstance(result["label"], bool) + self.assertEqual(result["label"], example["label"]) + + # both conversational and non-conversational examples + @parameterized.expand(itertools.product(tokenizers, conversational_examples + non_conversational_examples)) + def test_maybe_apply_chat_template(self, tokenizer_id, example): + tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) + result = maybe_apply_chat_template(example, tokenizer) + + # Checking if the result is a dictionary + self.assertIsInstance(result, dict) + + # The chat template should be applied to the following keys + for key in ["prompt", "chosen", "rejected", "completion"]: + if key in example: + self.assertIn(key, result) + self.assertIsInstance(result[key], str) + + # Exception for messages, the key is "text" once the chat template is applied + if "messages" in example: + self.assertIn("text", result) + self.assertIsInstance(result["text"], str) + + # The label should be kept + if "label" in example: + self.assertIn("label", result) + self.assertIsInstance(result["label"], bool) + self.assertEqual(result["label"], example["label"]) + + def test_apply_chat_template_with_tools(self): + tokenizer = AutoProcessor.from_pretrained("trl-internal-testing/tiny-LlamaForCausalLM-3.2") + + # Define dummy test tools + def get_current_temperature(location: str): + """ + Gets the temperature at a given location. + + Args: + location: The location to get the temperature for + """ + return 22.0 + + # Define test case + test_case = { + "prompt": [ + {"content": "Whats the temperature in London?", "role": "user"}, + ] + } + # Test with tools + result_with_tools = apply_chat_template(test_case, tokenizer, tools=[get_current_temperature]) + + # Verify tools are included in the output + self.assertIn("get_current_temperature", result_with_tools["prompt"]) + + # Test without tools + result_without_tools = apply_chat_template(test_case, tokenizer, tools=None) + + # Verify tools are not included in the output + self.assertNotIn("get_current_temperature", result_without_tools["prompt"]) + + +class UnpairPreferenceDatasetTester(unittest.TestCase): + paired_dataset = Dataset.from_dict( + { + "prompt": ["The sky is", "The sun is"], + "chosen": [" blue.", " in the sky."], + "rejected": [" green.", " in the sea."], + } + ) + + unpaired_dataset = Dataset.from_dict( + { + "prompt": ["The sky is", "The sun is", "The sky is", "The sun is"], + "completion": [" blue.", " in the sky.", " green.", " in the sea."], + "label": [True, True, False, False], + } + ) + + def test_unpair_preference_dataset(self): + # Test that a paired dataset is correctly converted to unpaired + unpaired_dataset = unpair_preference_dataset(self.paired_dataset) + self.assertEqual( + unpaired_dataset.to_dict(), + self.unpaired_dataset.to_dict(), + "The paired dataset should be converted to unpaired.", + ) + + def test_unpair_preference_dataset_dict(self): + # Test that a paired dataset dict is correctly converted to unpaired + paired_dataset_dict = DatasetDict({"abc": self.paired_dataset}) + unpaired_dataset_dict = unpair_preference_dataset(paired_dataset_dict) + self.assertEqual( + unpaired_dataset_dict["abc"].to_dict(), + self.unpaired_dataset.to_dict(), + "The paired dataset should be converted to unpaired.", + ) + + def test_maybe_unpair_preference_dataset(self): + # Test that a paired dataset is correctly converted to unpaired with maybe_unpair_preference_dataset + unpaired_dataset = maybe_unpair_preference_dataset(self.paired_dataset) + self.assertEqual( + unpaired_dataset.to_dict(), + self.unpaired_dataset.to_dict(), + "The paired dataset should be converted to unpaired.", + ) + + def test_maybe_unpair_preference_dataset_dict(self): + # Test that a paired dataset dict is correctly converted to unpaired with maybe_unpair_preference_dataset + paired_dataset_dict = DatasetDict({"abc": self.paired_dataset}) + unpaired_dataset_dict = maybe_unpair_preference_dataset(paired_dataset_dict) + self.assertEqual( + unpaired_dataset_dict["abc"].to_dict(), + self.unpaired_dataset.to_dict(), + "The paired dataset should be converted to unpaired.", + ) + + def test_maybe_unpair_preference_dataset_already_paired(self): + # Test that a paired dataset remains unchanged with maybe_unpair_preference_dataset + unpaired_dataset = maybe_unpair_preference_dataset(self.unpaired_dataset) + self.assertEqual( + unpaired_dataset.to_dict(), + self.unpaired_dataset.to_dict(), + "The unpaired dataset should remain unchanged.", + ) + + def test_maybe_unpair_preference_dataset_dict_already_paired(self): + # Test that a paired dataset dict remains unchanged with maybe_unpair_preference_dataset + unpaired_dataset_dict = maybe_unpair_preference_dataset(DatasetDict({"abc": self.unpaired_dataset})) + self.assertEqual( + unpaired_dataset_dict["abc"].to_dict(), + self.unpaired_dataset.to_dict(), + "The unpaired dataset should remain unchanged.", + ) + + +class ExtractPromptTester(unittest.TestCase): + example_implicit_prompt_conversational = { + "chosen": [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}, + ], + "rejected": [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is green."}, + ], + } + + example_explicit_prompt_conversational = { + "prompt": [ + {"role": "user", "content": "What color is the sky?"}, + ], + "chosen": [ + {"role": "assistant", "content": "It is blue."}, + ], + "rejected": [ + {"role": "assistant", "content": "It is green."}, + ], + } + + example_implicit_prompt_standard = { + "chosen": "The sky is blue.", + "rejected": "The sky is green.", + } + + example_explicit_prompt_standard = { + "prompt": "The sky is", + "chosen": " blue.", + "rejected": " green.", + } + + def test_extract_prompt_conversational(self): + # Test that the prompt is correctly extracted from the dataset + example_extracted_prompt = extract_prompt(self.example_implicit_prompt_conversational) + self.assertEqual( + example_extracted_prompt, + self.example_explicit_prompt_conversational, + "The prompt is not correctly extracted from the dataset.", + ) + + def test_maybe_extract_prompt_conversational(self): + # Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt + example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_conversational) + self.assertEqual( + example_extracted_prompt, + self.example_explicit_prompt_conversational, + "The prompt is not correctly extracted from the dataset.", + ) + + def test_maybe_extract_prompt_conversational_already_explicit(self): + # Test that the prompt remains unchanged with maybe_extract_prompt + example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_conversational) + self.assertEqual( + example_extracted_prompt, + self.example_explicit_prompt_conversational, + "The prompt should remain unchanged.", + ) + + def test_extract_prompt_standard(self): + # Test that the prompt is correctly extracted from the dataset + example_extracted_prompt = extract_prompt(self.example_implicit_prompt_standard) + self.assertEqual( + example_extracted_prompt, + self.example_explicit_prompt_standard, + "The prompt is not correctly extracted from the dataset.", + ) + + def test_maybe_extract_prompt_standard(self): + # Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt + example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_standard) + self.assertEqual( + example_extracted_prompt, + self.example_explicit_prompt_standard, + "The prompt is not correctly extracted from the dataset.", + ) + + def test_maybe_extract_prompt_standard_already_explicit(self): + # Test that the prompt remains unchanged with maybe_extract_prompt + example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_standard) + self.assertEqual( + example_extracted_prompt, + self.example_explicit_prompt_standard, + "The prompt should remain unchanged.", + ) + + +class TestPackExamples(unittest.TestCase): + def test_larger_chunks(self): + examples = { + "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]], + "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]], + } + seq_length = 5 + expected_output = { + "input_ids": [[1, 2, 3, 4, 5], [6, 7, 8]], + "attention_mask": [[0, 1, 1, 0, 0], [1, 1, 1]], + } + result = pack_examples(examples, seq_length) + self.assertEqual(result, expected_output) + + def test_smaller_chunks(self): + examples = { + "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]], + "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]], + } + seq_length = 2 + expected_output = { + "input_ids": [[1, 2], [3, 4], [5, 6], [7, 8]], + "attention_mask": [[0, 1], [1, 0], [0, 1], [1, 1]], + } + result = pack_examples(examples, seq_length) + self.assertEqual(result, expected_output) + + def test_with_dataset(self): + examples = { + "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]], + "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]], + } + dataset = Dataset.from_dict(examples) + seq_length = 3 + expected_output = { + "input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]], + "attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]], + } + dataset = dataset.map(pack_examples, batched=True, fn_kwargs={"seq_length": seq_length}) + self.assertEqual(dataset.to_dict(), expected_output) + + +class TestPackDatasetWrapped(unittest.TestCase): + def test_with_dataset(self): + examples = { + "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]], + "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]], + } + dataset = Dataset.from_dict(examples) + seq_length = 3 + expected_output = { + "input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]], + "attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]], + } + dataset = pack_dataset(dataset, seq_length, strategy="wrapped") + self.assertEqual(dataset.to_dict(), expected_output) + + def test_with_iterable_dataset(self): + examples = { + "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]], + "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]], + } + dataset = Dataset.from_dict(examples).to_iterable_dataset() + seq_length = 3 + expected_output = { + "input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]], + "attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]], + } + dataset = pack_dataset(dataset, seq_length, strategy="wrapped") + num_examples = len(examples[next(iter(examples))]) + self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output) + + +class TestPackDatasetFfd(unittest.TestCase): + def test_simple(self): + examples = { + "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]], + "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]], + } + dataset = Dataset.from_dict(examples) + seq_length = 4 + expected_output = { + "input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]], + "attention_mask": [[0, 0, 1, 1], [0, 1, 1, 1]], + "position_ids": [[0, 1, 2, 3], [0, 1, 2, 0]], + } + dataset = pack_dataset(dataset, seq_length, strategy="ffd") + self.assertEqual(dataset.to_dict(), expected_output) + + def test_with_iterable_dataset(self): + examples = { + "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]], + "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]], + } + dataset = Dataset.from_dict(examples).to_iterable_dataset() + seq_length = 4 + expected_output = { + "input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]], + "attention_mask": [[0, 0, 1, 1], [0, 1, 1, 1]], + "position_ids": [[0, 1, 2, 3], [0, 1, 2, 0]], + } + dataset = pack_dataset(dataset, seq_length, strategy="ffd") + num_examples = len(examples[next(iter(examples))]) + self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output) + + def test_with_truncation(self): + examples = { + "input_ids": [[1, 2, 3, 4, 5], [6, 7], [8, 9, 10, 11], [12]], + "attention_mask": [[1, 1, 1, 1, 1], [1, 1], [1, 1, 1, 1], [1]], + } + dataset = Dataset.from_dict(examples) + seq_length = 4 + expected_output = { + "input_ids": [[1, 2, 3, 4], [8, 9, 10, 11], [6, 7, 12]], + "attention_mask": [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]], + "position_ids": [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 0]], + } + dataset = pack_dataset(dataset, seq_length, strategy="ffd") + self.assertEqual(dataset.to_dict(), expected_output) + + +class TestTruncateExamples(unittest.TestCase): + def test_with_dataset(self): + examples = { + "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]], + "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]], + } + dataset = Dataset.from_dict(examples) + max_length = 2 + expected_output = { + "input_ids": [[1, 2], [4, 5], [8]], + "attention_mask": [[0, 1], [0, 0], [1]], + } + dataset = truncate_dataset(dataset, max_length) + self.assertEqual(dataset.to_dict(), expected_output) + + def test_with_iterable_dataset(self): + examples = { + "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]], + "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]], + } + dataset = Dataset.from_dict(examples).to_iterable_dataset() + max_length = 2 + expected_output = { + "input_ids": [[1, 2], [4, 5], [8]], + "attention_mask": [[0, 1], [0, 0], [1]], + } + dataset = truncate_dataset(dataset, max_length) + num_examples = len(examples[next(iter(examples))]) + self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output) + + def test_with_extra_column(self): + examples = { + "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]], + "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]], + "my_column": ["a", "b", "c"], + } + dataset = Dataset.from_dict(examples) + max_length = 2 + expected_output = { + "input_ids": [[1, 2], [4, 5], [8]], + "attention_mask": [[0, 1], [0, 0], [1]], + "my_column": ["a", "b", "c"], + } + dataset = truncate_dataset(dataset, max_length) + self.assertEqual(dataset.to_dict(), expected_output) + + +class TestMaybeConvertToChatML(unittest.TestCase): + def test_with_conversations_key(self): + # Particular case where the key is "conversations": we rename it to "messages" + example = { + "conversations": [ + {"from": "user", "value": "What color is the sky?"}, + {"from": "assistant", "value": "It is blue."}, + ] + } + expected_output = { + "messages": [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}, + ] + } + self.assertEqual(maybe_convert_to_chatml(example), expected_output) + + def test_without_conversations_key(self): + # Same as before, but we don't rename the keys + example = { + "prompt": [{"from": "user", "value": "What color is the sky?"}], + "completion": [{"from": "assistant", "value": "It is blue."}], + } + expected_output = { + "prompt": [{"role": "user", "content": "What color is the sky?"}], + "completion": [{"role": "assistant", "content": "It is blue."}], + } + self.assertEqual(maybe_convert_to_chatml(example), expected_output) + + def test_not_conversional(self): + # When not needed, the example should remain unchanged + example = {"text": "The sky is blue."} + self.assertEqual(maybe_convert_to_chatml(example), example) + + def test_already_chatml(self): + # When the example is already in ChatML format, it should remain unchanged + example = { + "messages": [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}, + ] + } + self.assertEqual(maybe_convert_to_chatml(example), example) + + +# Run the tests +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_dataset_formatting.py b/tests/test_dataset_formatting.py new file mode 100644 index 0000000000000000000000000000000000000000..05397f092da1637d31ec8b9bdec096c74237490d --- /dev/null +++ b/tests/test_dataset_formatting.py @@ -0,0 +1,155 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import Callable + +from datasets import Dataset, load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from trl.extras.dataset_formatting import get_formatting_func_from_dataset +from trl.models.utils import ChatMlSpecialTokens, setup_chat_format + + +class DatasetFormattingTestCase(unittest.TestCase): + def setUp(self): + self.llama_tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-MistralForCausalLM-0.1") + self.chatml_tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + + def test_get_formatting_func_from_dataset_with_chatml_messages(self): + dataset = Dataset.from_dict( + { + "messages": [ + [ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi, how can I help you?"}, + ] + ] + } + ) + + # Llama tokenizer + formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer) + self.assertIsInstance(formatting_func, Callable) + formatted_text = formatting_func(dataset[0]) + expected = " [INST] You are helpful\n\nHello [/INST] Hi, how can I help you?" + self.assertEqual(formatted_text, expected) + formatted_text = formatting_func(dataset[0:1]) + self.assertListEqual(formatted_text, [expected]) + + # ChatML tokenizer + formatting_func = get_formatting_func_from_dataset(dataset, self.chatml_tokenizer) + formatted_text = formatting_func(dataset[0]) + expected = "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n" + self.assertEqual(formatted_text, expected) + formatted_text = formatting_func(dataset[0:1]) + self.assertListEqual(formatted_text, [expected]) + + def test_get_formatting_func_from_dataset_with_chatml_conversations(self): + dataset = Dataset.from_dict( + { + "conversations": [ + [ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi, how can I help you?"}, + ] + ] + } + ) + # Llama tokenizer + formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer) + self.assertIsInstance(formatting_func, Callable) + formatted_text = formatting_func(dataset[0]) + expected = " [INST] You are helpful\n\nHello [/INST] Hi, how can I help you?" + self.assertEqual(formatted_text, expected) + formatted_text = formatting_func(dataset[0:1]) + self.assertListEqual(formatted_text, [expected]) + + # ChatML tokenizer + formatting_func = get_formatting_func_from_dataset(dataset, self.chatml_tokenizer) + formatted_text = formatting_func(dataset[0]) + expected = "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n" + self.assertEqual(formatted_text, expected) + formatted_text = formatting_func(dataset[0:1]) + self.assertListEqual(formatted_text, [expected]) + + def test_get_formatting_func_from_dataset_with_instruction(self): + dataset = Dataset.from_list( + [{"prompt": "What is 2+2?", "completion": "4"}, {"prompt": "What is 3+3?", "completion": "6"}] + ) + formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer) + self.assertIsNotNone(formatting_func) + self.assertIsInstance(formatting_func, Callable) + formatted_text = formatting_func(dataset[0]) + self.assertEqual(formatted_text, " [INST] What is 2+2? [/INST] 4") + formatted_text = formatting_func(dataset[0:1]) + self.assertListEqual(formatted_text, [" [INST] What is 2+2? [/INST] 4"]) + + def test_get_formatting_func_from_dataset_from_hub(self): + ds_1 = load_dataset("philschmid/trl-test-instruction", split="train") + ds_2 = load_dataset("philschmid/dolly-15k-oai-style", split="train") + for ds in [ds_1, ds_2]: + formatting_func = get_formatting_func_from_dataset(ds, self.llama_tokenizer) + self.assertIsNotNone(formatting_func) + self.assertIsInstance(formatting_func, Callable) + ds_3 = load_dataset("philschmid/guanaco-sharegpt-style", split="train") + formatting_func = get_formatting_func_from_dataset(ds_3, self.llama_tokenizer) + self.assertIsNone(formatting_func) + + def test_get_formatting_func_from_dataset_with_unknown_format(self): + dataset = Dataset.from_dict({"text": "test"}) + formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer) + self.assertIsNone(formatting_func) + + +class SetupChatFormatTestCase(unittest.TestCase): + def setUp(self): + self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + # remove built-in chat_template to simulate a model having no chat_template + self.tokenizer.chat_template = None + + def test_setup_chat_format(self): + modified_model, modified_tokenizer = setup_chat_format( + self.model, self.tokenizer, format="chatml", resize_to_multiple_of=64 + ) + + _chatml = ChatMlSpecialTokens() + # Check if special tokens are correctly set + self.assertEqual(modified_tokenizer.eos_token, "<|im_end|>") + self.assertEqual(modified_tokenizer.pad_token, "<|im_end|>") + self.assertEqual(modified_tokenizer.bos_token, "<|im_start|>") + self.assertEqual(modified_tokenizer.eos_token, _chatml.eos_token) + self.assertEqual(modified_tokenizer.pad_token, _chatml.pad_token) + self.assertEqual(modified_tokenizer.bos_token, _chatml.bos_token) + self.assertEqual((self.model.get_input_embeddings().weight.shape[0] % 64), 0) + + def test_example_with_setup_model(self): + modified_model, modified_tokenizer = setup_chat_format( + self.model, + self.tokenizer, + ) + messages = [ + {"role": "system", "content": "You are helpful"}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi, how can I help you?"}, + ] + prompt = modified_tokenizer.apply_chat_template(messages, tokenize=False) + + self.assertEqual( + prompt, + "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n", + ) diff --git a/tests/test_ddpo_trainer.py b/tests/test_ddpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..5c471093fd0c27f5bb16cd4e78bd8fc4848a157f --- /dev/null +++ b/tests/test_ddpo_trainer.py @@ -0,0 +1,129 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import torch +from transformers.utils import is_peft_available + +from trl.import_utils import is_diffusers_available + +from .testing_utils import require_diffusers + + +if is_diffusers_available() and is_peft_available(): + from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline + + +def scorer_function(images, prompts, metadata): + return torch.randn(1) * 3.0, {} + + +def prompt_function(): + return ("cabbages", {}) + + +@require_diffusers +class DDPOTrainerTester(unittest.TestCase): + """ + Test the DDPOTrainer class. + """ + + def setUp(self): + self.training_args = DDPOConfig( + num_epochs=2, + train_gradient_accumulation_steps=1, + per_prompt_stat_tracking_buffer_size=32, + sample_num_batches_per_epoch=2, + sample_batch_size=2, + mixed_precision=None, + save_freq=1000000, + ) + pretrained_model = "hf-internal-testing/tiny-stable-diffusion-torch" + pretrained_revision = "main" + + pipeline = DefaultDDPOStableDiffusionPipeline( + pretrained_model, pretrained_model_revision=pretrained_revision, use_lora=False + ) + + self.trainer = DDPOTrainer(self.training_args, scorer_function, prompt_function, pipeline) + + return super().setUp() + + def tearDown(self) -> None: + gc.collect() + + def test_loss(self): + advantage = torch.tensor([-1.0]) + clip_range = 0.0001 + ratio = torch.tensor([1.0]) + loss = self.trainer.loss(advantage, clip_range, ratio) + self.assertEqual(loss.item(), 1.0) + + def test_generate_samples(self): + samples, output_pairs = self.trainer._generate_samples(1, 2) + self.assertEqual(len(samples), 1) + self.assertEqual(len(output_pairs), 1) + self.assertEqual(len(output_pairs[0][0]), 2) + + def test_calculate_loss(self): + samples, _ = self.trainer._generate_samples(1, 2) + sample = samples[0] + + latents = sample["latents"][0, 0].unsqueeze(0) + next_latents = sample["next_latents"][0, 0].unsqueeze(0) + log_probs = sample["log_probs"][0, 0].unsqueeze(0) + timesteps = sample["timesteps"][0, 0].unsqueeze(0) + prompt_embeds = sample["prompt_embeds"] + advantage = torch.tensor([1.0], device=prompt_embeds.device) + + self.assertTupleEqual(latents.shape, (1, 4, 64, 64)) + self.assertTupleEqual(next_latents.shape, (1, 4, 64, 64)) + self.assertTupleEqual(log_probs.shape, (1,)) + self.assertTupleEqual(timesteps.shape, (1,)) + self.assertTupleEqual(prompt_embeds.shape, (2, 77, 32)) + loss, approx_kl, clipfrac = self.trainer.calculate_loss( + latents, timesteps, next_latents, log_probs, advantage, prompt_embeds + ) + + self.assertTrue(torch.isfinite(loss.cpu())) + + +@require_diffusers +class DDPOTrainerWithLoRATester(DDPOTrainerTester): + """ + Test the DDPOTrainer class. + """ + + def setUp(self): + self.training_args = DDPOConfig( + num_epochs=2, + train_gradient_accumulation_steps=1, + per_prompt_stat_tracking_buffer_size=32, + sample_num_batches_per_epoch=2, + sample_batch_size=2, + mixed_precision=None, + save_freq=1000000, + ) + pretrained_model = "hf-internal-testing/tiny-stable-diffusion-torch" + pretrained_revision = "main" + + pipeline = DefaultDDPOStableDiffusionPipeline( + pretrained_model, pretrained_model_revision=pretrained_revision, use_lora=True + ) + + self.trainer = DDPOTrainer(self.training_args, scorer_function, prompt_function, pipeline) + + return super().setUp() diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..c7dfa70e76785ae58c1ecfa1ea85c5a3610ca3dd --- /dev/null +++ b/tests/test_dpo_trainer.py @@ -0,0 +1,1504 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import tempfile +import unittest +from unittest.mock import MagicMock + +import numpy as np +import torch +from datasets import Dataset, features, load_dataset +from parameterized import parameterized +from transformers import ( + AutoModelForCausalLM, + AutoModelForSeq2SeqLM, + AutoModelForVision2Seq, + AutoProcessor, + AutoTokenizer, + PreTrainedTokenizerBase, + is_vision_available, +) +from transformers.testing_utils import ( + get_device_properties, + require_liger_kernel, + require_peft, + require_torch_gpu_if_bnb_not_multi_backend_enabled, + require_vision, +) + +from trl import DPOConfig, DPOTrainer, FDivergenceType + +from .testing_utils import require_bitsandbytes, require_no_wandb + + +if is_vision_available(): + from PIL import Image + + +class TestTokenizeRow(unittest.TestCase): + def setUp(self): + # Set up the mock tokenizer with specific behaviors + self.tokenizer = MagicMock(spec=PreTrainedTokenizerBase) + self.tokenizer.bos_token_id = 0 + self.tokenizer.eos_token_id = 2 + + # Define mock return values for the tokenizer's 'input_ids' for the different text inputs + self.tokenizer.return_value = { + "input_ids": {"The sky is": [464, 6766, 318], " blue": [4171], " green": [4077]} + } + + # Define tokenizer behavior when called + def mock_tokenizer_call(text, add_special_tokens): + token_map = { + "The sky is": {"input_ids": [464, 6766, 318]}, + " blue": {"input_ids": [4171]}, + " green": {"input_ids": [4077]}, + } + return token_map[text] + + self.tokenizer.side_effect = mock_tokenizer_call + + def test_tokenize_row_no_truncation_no_special_tokens(self): + # Define the input features + features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"} + + # Call the method with no truncation and no special tokens + result = DPOTrainer.tokenize_row( + features=features, + processing_class=self.tokenizer, + max_prompt_length=None, + max_completion_length=None, + add_special_tokens=False, + ) + + # Assert the correct output without truncation or special tokens + self.assertEqual( + result, + { + "prompt_input_ids": [464, 6766, 318], + "chosen_input_ids": [4171, 2], # eos_token added + "rejected_input_ids": [4077, 2], # eos_token added + }, + ) + + def test_tokenize_row_with_truncation(self): + # Define the input features + features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"} + + # Call the method with truncation + result = DPOTrainer.tokenize_row( + features=features, + processing_class=self.tokenizer, + max_prompt_length=2, + max_completion_length=1, + add_special_tokens=False, + ) + + # Assert the correct output with truncation applied + self.assertEqual( + result, + { + "prompt_input_ids": [6766, 318], # truncated to the last 2 tokens + "chosen_input_ids": [4171], # truncated to 1 token + "rejected_input_ids": [4077], # truncated to 1 token + }, + ) + + def test_tokenize_row_with_special_tokens(self): + # Define the input features + features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"} + + # Call the method with special tokens + result = DPOTrainer.tokenize_row( + features=features, + processing_class=self.tokenizer, + max_prompt_length=None, + max_completion_length=None, + add_special_tokens=True, + ) + + # Assert the correct output with special tokens added + self.assertEqual( + result, + { + "prompt_input_ids": [0, 464, 6766, 318, 2], # bos_token and eos_token added + "chosen_input_ids": [4171, 2], # eos_token added + "rejected_input_ids": [4077, 2], # eos_token added + }, + ) + + def test_tokenize_row_with_truncation_and_special_tokens(self): + # Define the input features + features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"} + + # Call the method with both truncation and special tokens + result = DPOTrainer.tokenize_row( + features=features, + processing_class=self.tokenizer, + max_prompt_length=4, + max_completion_length=1, + add_special_tokens=True, + ) + + # Assert the correct output with both truncation and special tokens + self.assertEqual( + result, + { + "prompt_input_ids": [464, 6766, 318, 2], # truncated to 4 tokens with bos_token and eos_token + "chosen_input_ids": [4171], # truncated to 1 token + "rejected_input_ids": [4077], # truncated to 1 token + }, + ) + + +class DPOTrainerTester(unittest.TestCase): + def setUp(self): + self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer.pad_token = self.tokenizer.eos_token + + # get t5 as seq2seq example: + model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration" + self.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id) + self.t5_ref_model = AutoModelForSeq2SeqLM.from_pretrained(model_id) + self.t5_tokenizer = AutoTokenizer.from_pretrained(model_id) + + def test_train(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") + tokenizer = AutoTokenizer.from_pretrained(model_id) + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + learning_rate=9e-1, + report_to="none", + ) + trainer = DPOTrainer( + model=model_id, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + + @parameterized.expand( + [ + ("sigmoid",), + ("hinge",), + ("ipo",), + ("exo_pair",), + ("nca_pair",), + ("robust",), + ("bco_pair",), + ("sppo_hard",), + ("aot",), + ("aot_pair",), + ("discopop",), + ("apo_zero",), + ("apo_down",), + ] + ) + def test_train_loss_types(self, loss_type): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") + tokenizer = AutoTokenizer.from_pretrained(model_id) + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + learning_rate=9e-1, + loss_type=loss_type, + report_to="none", + ) + trainer = DPOTrainer( + model=model_id, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + + def test_dpo_trainer_with_weighting(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + learning_rate=9e-1, + use_weighting=True, + report_to="none", + ) + + trainer = DPOTrainer( + model=self.model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + + @parameterized.expand( + [ + (None, "Test when rpo_alpha is set to None"), + (0.5, "Test when rpo_alpha is set to 0.5"), + ] + ) + def test_dpo_trainer_without_providing_ref_model(self, rpo_alpha, _): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + precompute_ref_log_probs=True, + rpo_alpha=rpo_alpha, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + trainer = DPOTrainer( + model=self.model, + ref_model=None, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.equal(param, new_param)) + + def test_dpo_trainer_with_ref_model_is_model(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + with self.assertRaises(ValueError): + DPOTrainer( + model=self.model, + ref_model=self.model, # ref_model can't be the same as model + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + ) + + def test_precompute_ref_batch_size(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + precompute_ref_log_probs=True, + precompute_ref_batch_size=4, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + trainer = DPOTrainer( + model=self.model, + ref_model=self.ref_model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + + @require_peft + def test_dpo_trainer_without_providing_ref_model_with_lora(self): + from peft import LoraConfig + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + precompute_ref_log_probs=True, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + trainer = DPOTrainer( + model=self.model, + ref_model=None, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + if "lora" in n: + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.equal(param, new_param)) + + def test_dpo_trainer_padding_token_is_none(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + tokenizer = AutoTokenizer.from_pretrained(self.model_id) + tokenizer.pad_token = None + + with self.assertRaisesRegex( + ValueError, + expected_regex=r"`padding_value` is not specified in `DPOConfig`, and `pad_token_id` is missing in " + r"the `processing_class`. Please either set the `padding_value` argument in `DPOConfig`, or set " + r"`tokenizer.pad_token` \(e.g., `tokenizer.pad_token = tokenizer.eos_token`\) before instantiating " + r"the trainer.", + ): + trainer = DPOTrainer( + model=self.model, + ref_model=None, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + trainer.train() + + def test_dpo_trainer_w_dataset_num_proc(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + dataset_num_proc=2, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + tokenizer = AutoTokenizer.from_pretrained(self.model_id) + + trainer = DPOTrainer( + model=self.model, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + trainer.train() + + def test_tr_dpo_trainer(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + eval_strategy="steps", + precompute_ref_log_probs=False, + sync_ref_model=True, + ref_model_mixup_alpha=0.5, + ref_model_sync_steps=1, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + trainer = DPOTrainer( + model=self.model, + ref_model=self.ref_model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + # params of the ref model as its the same as the model + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.ref_model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.equal(param, new_param)) + + @require_no_wandb + def test_dpo_trainer_generate_during_eval_no_wandb(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + generate_during_eval=True, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + with self.assertRaisesRegex( + ValueError, + expected_regex="`generate_during_eval=True` requires Weights and Biases or Comet to be installed." + " Please install `wandb` or `comet-ml` to resolve.", + ): + DPOTrainer( + model=self.model, + ref_model=None, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + @require_peft + def test_dpo_lora_save(self): + from peft import LoraConfig, get_peft_model + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + # lora model + model = AutoModelForCausalLM.from_pretrained(self.model_id) + model_peft = get_peft_model(model, lora_config) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + precompute_ref_log_probs=True, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + # dpo train lora model with a lora config + trainer = DPOTrainer( + model=model_peft, + ref_model=None, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + # train the model + trainer.train() + + # save peft adapter + trainer.save_model() + + try: + AutoModelForCausalLM.from_pretrained(tmp_dir) + except OSError: + self.fail("Loading the saved peft adapter failed") + + @require_peft + @require_torch_gpu_if_bnb_not_multi_backend_enabled + def test_dpo_lora_bf16_autocast_llama(self): + # Note this test only works on compute capability > 7 GPU devices + from peft import LoraConfig + + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + tokenizer = AutoTokenizer.from_pretrained(model_id) + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + # lora model + model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + eval_strategy="steps", + bf16=True, + beta=0.1, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + # dpo train lora model with a lora config + trainer = DPOTrainer( + model=model, + ref_model=None, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + # train the model + trainer.train() + + # save peft adapter + trainer.save_model() + + @parameterized.expand( + [ + ("sigmoid", False, False), + ("sigmoid", False, True), + ("sigmoid", True, False), + ("sigmoid", True, True), + ("ipo", False, False), + ("ipo", False, True), + ("ipo", True, False), + ("ipo", True, True), + ("aot_pair", False, False), + ("aot_pair", False, True), + ("aot_pair", True, False), + ("aot_pair", True, True), + ("aot", False, False), + ("aot", False, True), + ("aot", True, False), + ("aot", True, True), + ("bco_pair", False, False), + ("bco_pair", False, True), + ("bco_pair", True, False), + ("bco_pair", True, True), + ("robust", False, False), + ("robust", False, True), + ("robust", True, False), + ("robust", True, True), + ] + ) + @require_bitsandbytes + @require_peft + @unittest.skipIf( + get_device_properties()[0] == "cuda" and get_device_properties()[1] < 8, + "Skipping because bf16 not supported on CUDA GPU with capability < 8.0", + ) + def test_dpo_lora_bf16_autocast(self, loss_type, pre_compute, gen_during_eval): + from peft import LoraConfig + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + # lora model + model = AutoModelForCausalLM.from_pretrained(self.model_id, load_in_4bit=True) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + eval_strategy="steps", + bf16=True, + beta=0.1, + generate_during_eval=gen_during_eval, + loss_type=loss_type, + precompute_ref_log_probs=pre_compute, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + # dpo train lora model with a lora config + trainer = DPOTrainer( + model=model, + ref_model=None, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + # train the model + trainer.train() + + # save peft adapter + trainer.save_model() + + @require_peft + def test_dpo_lora_tags(self): + from peft import LoraConfig + + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + tokenizer = AutoTokenizer.from_pretrained(model_id) + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + # lora model + model = AutoModelForCausalLM.from_pretrained(model_id) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + # dpo train lora model with a lora config + trainer = DPOTrainer( + model=model, + ref_model=None, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + for tag in ["dpo", "trl"]: + self.assertIn(tag, trainer.model.model_tags) + + @require_peft + def test_dpo_tags(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + tokenizer = AutoTokenizer.from_pretrained(model_id) + + # lora model + model = AutoModelForCausalLM.from_pretrained(model_id) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + # dpo train lora model with a lora config + trainer = DPOTrainer( + model=model, + ref_model=None, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + for tag in ["dpo", "trl"]: + self.assertIn(tag, trainer.model.model_tags) + + @require_peft + def test_dpo_lora_force_use_ref(self): + from peft import LoraConfig, get_peft_model + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + # lora model + model = AutoModelForCausalLM.from_pretrained(self.model_id) + model_peft = get_peft_model(model, lora_config) + + ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + with self.assertRaises(ValueError): + # passing a peft_model as model and ref_model should error out, + # unless you pass `force_use_ref_model` + trainer = DPOTrainer( + model=model_peft, + ref_model=ref_model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + force_use_ref_model=True, + report_to="none", + ) + + trainer = DPOTrainer( + model=model_peft, + ref_model=ref_model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + # train the model + trainer.train() + + def test_dpo_trainer_torch_dtype(self): + # See https://github.com/huggingface/trl/issues/1751 + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=1, + model_init_kwargs={"torch_dtype": "float16"}, + ref_model_init_kwargs={"torch_dtype": "float16"}, + report_to="none", + ) + + trainer = DPOTrainer( + model=self.model_id, + ref_model=self.model_id, + processing_class=self.tokenizer, + args=training_args, + train_dataset=dummy_dataset["train"], + ) + self.assertEqual(trainer.model.config.torch_dtype, torch.float16) + self.assertEqual(trainer.ref_model.config.torch_dtype, torch.float16) + + # Now test when `torch_dtype` is provided but is wrong to either the model or the ref_model + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=1, + model_init_kwargs={"torch_dtype": -1}, + report_to="none", + ) + + with self.assertRaises(ValueError) as context: + _ = DPOTrainer( + model=self.model_id, + processing_class=self.tokenizer, + args=training_args, + train_dataset=dummy_dataset["train"], + ) + + self.assertIn( + "Invalid `torch_dtype` passed to `DPOConfig`. Expected either 'auto' or a string representing a `torch.dtype` (e.g., 'float32'), but got -1.", + str(context.exception), + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=1, + ref_model_init_kwargs={"torch_dtype": -1}, + report_to="none", + ) + + with self.assertRaises(ValueError) as context: + _ = DPOTrainer( + model=self.model_id, + ref_model=self.model_id, + processing_class=self.tokenizer, + args=training_args, + train_dataset=dummy_dataset["train"], + ) + + self.assertIn( + "Invalid `torch_dtype` passed to `DPOConfig`. Expected either 'auto' or a string representing a `torch.dtype` (e.g., 'float32'), but got -1.", + str(context.exception), + ) + + def test_dpo_loss_alpha_div_f(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + tokenizer = AutoTokenizer.from_pretrained(model_id) + + # lora model + model = AutoModelForCausalLM.from_pretrained(model_id) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + eval_strategy="steps", + f_divergence_type=FDivergenceType.ALPHA_DIVERGENCE.value, + f_alpha_divergence_coef=0.5, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + # dpo train lora model with a lora config + trainer = DPOTrainer( + model=model, + ref_model=None, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + # Fake chosen and rejected log probs + policy_chosen_logps = torch.FloatTensor([410.0, 0.1]) + policy_rejected_logps = torch.FloatTensor([810.5, 0.2]) + reference_chosen_logps = torch.FloatTensor([-610.0, -0.1]) + reference_rejected_logps = torch.FloatTensor([110.6, 0.5]) + losses, _, _ = trainer.dpo_loss( + policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps + ) + self.assertTrue(torch.isfinite(losses).cpu().numpy().all()) + + def test_dpo_loss_js_div_f(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + tokenizer = AutoTokenizer.from_pretrained(model_id) + + # lora model + model = AutoModelForCausalLM.from_pretrained(model_id) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + eval_strategy="steps", + f_divergence_type=FDivergenceType.JS_DIVERGENCE.value, + f_alpha_divergence_coef=0.5, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + # dpo train lora model with a lora config + trainer = DPOTrainer( + model=model, + ref_model=None, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + # Fake chosen and rejected log probs + policy_chosen_logps = torch.FloatTensor([410.0, 0.1]) + policy_rejected_logps = torch.FloatTensor([95.5, 0.2]) + reference_chosen_logps = torch.FloatTensor([-610.0, -0.1]) + reference_rejected_logps = torch.FloatTensor([5.5, 0.5]) + losses, _, _ = trainer.dpo_loss( + policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps + ) + self.assertTrue(torch.isfinite(losses).cpu().numpy().all()) + + def test_dpo_trainer_use_logits_to_keep(self): + model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2" + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token + + model = AutoModelForCausalLM.from_pretrained(model_id) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + use_logits_to_keep=True, + rpo_alpha=0.5, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + # dpo train lora model with a lora config + trainer = DPOTrainer( + model=model, + ref_model=None, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + training_args.use_logits_to_keep = False + trainer2 = DPOTrainer( + model=model, + ref_model=None, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + # Fake batch + prompt_input_ids = torch.randint(1, 1000, (2, 10)) + chosen_input_ids = torch.randint(1, 1000, (2, 5)) + rejected_input_ids = torch.randint(1, 1000, (2, 7)) + prompt_attention_mask = torch.ones_like(prompt_input_ids) + chosen_attention_mask = torch.ones_like(chosen_input_ids) + rejected_attention_mask = torch.ones_like(rejected_input_ids) + + batch = { + "prompt_input_ids": prompt_input_ids.to(model.device), + "chosen_input_ids": chosen_input_ids.to(model.device), + "rejected_input_ids": rejected_input_ids.to(model.device), + "prompt_attention_mask": prompt_attention_mask.to(model.device), + "chosen_attention_mask": chosen_attention_mask.to(model.device), + "rejected_attention_mask": rejected_attention_mask.to(model.device), + } + + output = trainer.concatenated_forward(model, batch) + output2 = trainer2.concatenated_forward(model, batch) + + np.testing.assert_allclose(output["nll_loss"].item(), output2["nll_loss"].item(), atol=1e-5) + np.testing.assert_allclose( + output["mean_chosen_logits"].item(), output2["mean_chosen_logits"].item(), atol=1e-5 + ) + np.testing.assert_allclose( + output["mean_rejected_logits"].item(), output2["mean_rejected_logits"].item(), atol=1e-5 + ) + + for i in range(output["chosen_logps"].shape[0]): + np.testing.assert_allclose( + output["chosen_logps"][i].item(), output2["chosen_logps"][i].item(), atol=1e-5 + ) + np.testing.assert_allclose( + output["rejected_logps"][i].item(), output2["rejected_logps"][i].item(), atol=1e-5 + ) + + trainer.train() + + def test_dpo_trainer_with_tools(self): + model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2" + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token + + model = AutoModelForCausalLM.from_pretrained(model_id) + + # Define dummy test tools + def get_current_temperature(location: str): + """ + Gets the temperature at a given location. + + Args: + location: The location to get the temperature for + """ + return 22.0 + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + tools=[get_current_temperature], + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference") + + trainer = DPOTrainer( + model=model, + ref_model=None, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + # We don't run the training, but at this stage, the dataset is supposed to be pre-processed. When + # pre-processing, we expect the available tools to be explicitly mentioned in the system prompt. That's + # what we're checking here + self.assertIn("get_current_temperature", tokenizer.decode(trainer.train_dataset["prompt_input_ids"][0])) + + def test_padding_free(self): + model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2" + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token + # Normally, we need `attn_implementation="flash_attention_2"` to that the model returns correct logits. + # Without it, the logits may be incorrect, but that's fine here. This test focuses only on the inner logic + # of padding_free. + model = AutoModelForCausalLM.from_pretrained(model_id) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + learning_rate=9e-1, + per_device_train_batch_size=2, + padding_free=True, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + trainer = DPOTrainer( + model=model, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + + def test_compute_metrics(self): + model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + tokenizer.pad_token = tokenizer.eos_token + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + def dummy_compute_metrics(*args, **kwargs): + return {"test": 0.0} + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + do_eval=True, + eval_strategy="steps", + eval_steps=3, + per_device_eval_batch_size=2, + report_to="none", + ) + + trainer = DPOTrainer( + model=model, + ref_model=ref_model, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + compute_metrics=dummy_compute_metrics, + ) + + trainer.train() + + self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0) + + def test_train_with_length_desensitization(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") + tokenizer = AutoTokenizer.from_pretrained(model_id) + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + learning_rate=9e-1, + ld_alpha=0.5, + report_to="none", + ) + trainer = DPOTrainer( + model=model_id, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + + @unittest.skipUnless(sys.version_info >= (3, 10), "Liger kernel is not supported on Python 3.9") + @parameterized.expand([(0.1,), (0.5,)]) + @require_liger_kernel + def test_dpo_trainer_with_liger(self, beta): + """Test DPO trainer with Liger loss enabled. + + This test verifies that: + 1. Training runs successfully with Liger loss + 2. Model parameters update as expected + 3. Loss values are reasonable and finite + 4. Training works with both default and custom beta values + """ + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + do_eval=True, + eval_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + beta=beta, + use_liger_loss=True, # Enable Liger loss + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + trainer = DPOTrainer( + model=self.model, + ref_model=self.ref_model, # Add reference model + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + # Store initial parameters + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + train_output = trainer.train() + + # Verify training completed successfully + self.assertIsNotNone(train_output) + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Verify loss is finite + self.assertTrue(np.isfinite(trainer.state.log_history[-1]["train_loss"])) + + # Check parameters have been updated + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # Only check non-zero parameters + if param.sum() != 0: + self.assertFalse(torch.equal(param, new_param)) + # Verify new parameters are finite + self.assertTrue(torch.isfinite(new_param).all()) + + # Verify model can still do forward pass after training + dummy_batch = next(iter(trainer.get_train_dataloader())) + model_inputs = { + "input_ids": dummy_batch["prompt_input_ids"], + "attention_mask": dummy_batch["prompt_attention_mask"], + } + with torch.no_grad(): + output = trainer.model(**model_inputs) + self.assertIsNotNone(output) + self.assertFalse("loss" in output.keys()) + + def test_train_with_iterable_dataset(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + dataset = load_dataset( + "trl-internal-testing/zen", + "standard_preference", + split="train", + streaming=True, + ) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + max_steps=3, + report_to="none", + ) + trainer = DPOTrainer( + model=model_id, + args=training_args, + processing_class=tokenizer, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + + +@require_vision +class DPOVisionTrainerTester(unittest.TestCase): + @parameterized.expand( + [ + ("trl-internal-testing/tiny-Idefics2ForConditionalGeneration",), + # ("trl-internal-testing/tiny-PaliGemmaForConditionalGeneration",), + ("trl-internal-testing/tiny-LlavaForConditionalGeneration",), + ("trl-internal-testing/tiny-LlavaNextForConditionalGeneration",), + ] + ) + def test_vdpo_trainer(self, model_id): + # fmt: off + dataset_dict = { + "prompt": [ + [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe the image in great detail."}]}], + [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Is this bus in the USA?"}]}], + [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Give a thorough description of the image."}]}], + [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Who are the people in the image?"}]}], + [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is written?"}]}], + ], + "chosen": [ + [{"role": "assistant", "content": [{"type": "text", "text": "The image features a modern, multi-colored train."}]}], + [{"role": "assistant", "content": [{"type": "text", "text": "Yes, it can be assumed that this bus is in the USA."}]}], + [{"role": "assistant", "content": [{"type": "text", "text": "The image features a forest path."}]}], + [{"role": "assistant", "content": [{"type": "text", "text": "There are two individuals, possibly girls or women."}]}], + [{"role": "assistant", "content": [{"type": "text", "text": '"ccpb".'}]}], + ], + "rejected": [ + [{"role": "assistant", "content": [{"type": "text", "text": "The image features a modern, colorful train."}]}], + [{"role": "assistant", "content": [{"type": "text", "text": "No, it's not in the USA."}]}], + [{"role": "assistant", "content": [{"type": "text", "text": "The image features a forest path surrounded by trees."}]}], + [{"role": "assistant", "content": [{"type": "text", "text": "In the image, there are two individuals."}]}], + [{"role": "assistant", "content": [{"type": "text", "text": '"ccpb".'}]}], + ], + "images": [ + [Image.fromarray(np.random.randint(0, 255, (92, 33, 3), dtype=np.uint8))], + [Image.fromarray(np.random.randint(0, 255, (64, 48, 3), dtype=np.uint8))], + [Image.fromarray(np.random.randint(0, 255, (80, 152, 3), dtype=np.uint8))], + [Image.fromarray(np.random.randint(0, 255, (57, 24, 3), dtype=np.uint8))], + [Image.fromarray(np.random.randint(0, 255, (102, 48, 3), dtype=np.uint8))], + ], + } + # fmt: on + dataset = Dataset.from_dict(dataset_dict) + dataset = dataset.cast_column("images", features.Sequence(features.Image())) + + # Instantiate the model and processor + model = AutoModelForVision2Seq.from_pretrained(model_id) + ref_model = AutoModelForVision2Seq.from_pretrained(model_id) + processor = AutoProcessor.from_pretrained(model_id) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + remove_unused_columns=False, + learning_rate=0.01, # increase learning rate to speed up test + max_prompt_length=None, # don't truncate to avoid issues with patch tokens + max_length=None, + report_to="none", + ) + trainer = DPOTrainer( + model=model, + ref_model=ref_model, + args=training_args, + processing_class=processor, + train_dataset=dataset, + eval_dataset=dataset, + ) + + # Save the initial weights, so we can check if they have changed after training + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the trainable params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + if model_id in [ + "trl-internal-testing/tiny-LlavaForConditionalGeneration", + "trl-internal-testing/tiny-LlavaNextForConditionalGeneration", + ] and ( + "vision_tower.vision_model.encoder.layers.1" in n + or "vision_tower.vision_model.post_layernorm.weight" in n + ): + # For some reason, these params are not updated. This is probably not related to TRL, but to + # the model itself. We should investigate this further, but for now we just skip these params. + continue + self.assertFalse( + torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated" + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_environments.py b/tests/test_environments.py new file mode 100644 index 0000000000000000000000000000000000000000..6ec6107ef637085072c03561773652b285df95fe --- /dev/null +++ b/tests/test_environments.py @@ -0,0 +1,278 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import patch + +import torch +from transformers import AutoTokenizer + +from trl import AutoModelForCausalLMWithValueHead, TextEnvironment, TextHistory + + +class DummyTool: + def __call__(self, text): + return text + + +def dummy_generate(histories): + for i in range(len(histories)): + histories[i].append_segment("test", torch.tensor([1, 2, 3]), system=False) + return histories + + +class TextHistoryTest(unittest.TestCase): + def test_text_history_init(self): + text = "Hello there!" + tokens = torch.tensor([1, 2, 3]) + + history = TextHistory(text, tokens) + self.assertEqual(history.text, text) + self.assertTrue(torch.equal(history.tokens, tokens)) + self.assertTrue(torch.equal(history.token_masks, torch.zeros_like(tokens))) + + history = TextHistory(text, tokens, system=False) + self.assertTrue(torch.equal(history.token_masks, torch.ones_like(tokens))) + + def test_text_history_append_segment(self): + text = "Hello there!" + tokens = torch.tensor([1, 2, 3]) + + history = TextHistory(text, tokens) + history.append_segment("General Kenobi!", torch.tensor([4, 5, 6]), system=False) + self.assertEqual(history.text, (text + "General Kenobi!")) + self.assertTrue(torch.equal(history.tokens, torch.tensor([1, 2, 3, 4, 5, 6]))) + self.assertTrue(torch.equal(history.token_masks, torch.tensor([0, 0, 0, 1, 1, 1]))) + + history.append_segment("You are a bold one!", torch.tensor([7, 8, 9])) + self.assertEqual(history.text, ((text + "General Kenobi!") + "You are a bold one!")) + self.assertTrue(torch.equal(history.tokens, torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]))) + self.assertTrue(torch.equal(history.token_masks, torch.tensor([0, 0, 0, 1, 1, 1, 0, 0, 0]))) + + def test_text_history_complete(self): + text = "Hello there!" + tokens = torch.tensor([1, 2, 3]) + history = TextHistory(text, tokens) + history.complete() + self.assertTrue(history.completed) + self.assertFalse(history.truncated) + + history.complete(truncated=True) + self.assertTrue(history.completed) + self.assertTrue(history.truncated) + + def test_text_history_last_segment(self): + text = "Hello there!" + tokens = torch.tensor([1, 2, 3]) + history = TextHistory(text, tokens) + history.append_segment("General Kenobi!", torch.tensor([4, 5, 6])) + history.append_segment("You are a bold one!", torch.tensor([7, 8, 9])) + self.assertEqual(history.last_text_segment, "You are a bold one!") + + def test_text_history_split_query_response(self): + text = "Hello there!" + tokens = torch.tensor([1, 2, 3]) + history = TextHistory(text, tokens) + history.append_segment("General Kenobi!", torch.tensor([4, 5, 6]), system=False) + history.append_segment("You are a bold one!", torch.tensor([7, 8, 9]), system=True) + query, response, mask = history.split_query_response_tokens() + + self.assertTrue(torch.equal(query, torch.tensor([1, 2, 3]))) + self.assertTrue(torch.equal(response, torch.tensor([4, 5, 6, 7, 8, 9]))) + self.assertTrue(torch.equal(mask, torch.tensor([1, 1, 1, 0, 0, 0]))) + + +class TextEnvironmentTester(unittest.TestCase): + def setUp(self): + # model_id + self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + + # get models and tokenizer + self.gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.model_id) + self.gpt2_tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.gpt2_tokenizer.pad_token = self.gpt2_tokenizer.eos_token + + def test_text_environment_setup(self): + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools=[DummyTool()], + reward_fn=lambda x: torch.tensor(1), + prompt="I am a prompt!\n", + ) + self.assertEqual(env.prompt, "I am a prompt!\n") + self.assertListEqual(list(env.tools.keys()), ["DummyTool"]) + self.assertIsInstance(env.tools["DummyTool"], DummyTool) + self.assertEqual(env.reward_fn("Hello there!"), 1) + + def test_text_environment_generate(self): + generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.gpt2_tokenizer.eos_token_id} + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools=[DummyTool()], + reward_fn=lambda x: torch.tensor(1), + prompt="I am a prompt!\n", + generation_kwargs=generation_kwargs, + ) + + input_texts = ["this is a test", "this is another, longer test"] + + model_inputs = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] + + generations_batched = env._generate_batched(model_inputs, batch_size=2) + generations_batched = self.gpt2_tokenizer.batch_decode(generations_batched) + + generations_single = [env._generate_batched([inputs], batch_size=1)[0] for inputs in model_inputs] + generations_single = self.gpt2_tokenizer.batch_decode(generations_single) + + self.assertEqual(generations_single, generations_batched) + + def test_text_environment_tool_call_parsing(self): + string_valid = "Something something Hello there!" + string_invalid_request = "Something something Hello there!" + string_invalid_call = "Something something Hello there!" + string_invalid_tool = "Something something |Tool2|Hello there!" + string_invalid_random = "<>abcdefghijklm<>nopqrstuvwxyz<>" + + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools=[DummyTool()], + reward_fn=lambda x: torch.tensor(1), + prompt="I am a prompt!\n", + ) + tool, response = env.parse_tool_call(string_valid) + self.assertEqual(tool, "Tool1") + self.assertEqual(response, "Hello there!") + + tool, response = env.parse_tool_call(string_invalid_request) + self.assertIsNone(tool) + self.assertIsNone(response) + + tool, response = env.parse_tool_call(string_invalid_call) + self.assertIsNone(tool) + self.assertIsNone(response) + + tool, response = env.parse_tool_call(string_invalid_tool) + self.assertIsNone(tool) + self.assertIsNone(response) + + tool, response = env.parse_tool_call(string_invalid_random) + self.assertIsNone(tool) + self.assertIsNone(response) + + def test_text_environment_tool_truncation(self): + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools={"dummy": lambda x: "a" * 1000}, + reward_fn=lambda x: torch.tensor(1), + prompt="I am a prompt!\n", + ) + + env.max_tool_response = 100 + history = env.step(TextHistory("Hello there!", torch.tensor([1, 2, 3]))) + self.assertEqual((len(history.last_text_segment) - len(env.response_token)), 100) + + env.max_tool_response = 500 + history = env.step(TextHistory("Hello there!", torch.tensor([1, 2, 3]))) + self.assertEqual((len(history.last_text_segment) - len(env.response_token)), 500) + + env.max_tool_response = 1001 + history = env.step(TextHistory("Hello there!", torch.tensor([1, 2, 3]))) + self.assertEqual((len(history.last_text_segment) - len(env.response_token)), 1000) + + env.max_tool_response = 2000 + history = env.step(TextHistory("Hello there!", torch.tensor([1, 2, 3]))) + self.assertEqual((len(history.last_text_segment) - len(env.response_token)), 1000) + + @patch.object(TextEnvironment, "generate", side_effect=dummy_generate) + def test_text_environment_max_calls(self, mock_generate): + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools={"DummyTool": DummyTool()}, + reward_fn=lambda x: [torch.tensor(1) for _ in x], + prompt="I am a prompt!\n", + ) + + env.max_turns = 1 + _, _, _, _, histories = env.run(["test"]) + self.assertEqual( + histories[0].text, + ("I am a prompt!\n" + "test") + (1 * "testtest"), + ) + + env.max_turns = 2 + _, _, _, _, histories = env.run(["test"]) + self.assertEqual( + histories[0].text, + ("I am a prompt!\n" + "test") + (2 * "testtest"), + ) + + env.max_turns = 4 + _, _, _, _, histories = env.run(["test"]) + self.assertEqual( + histories[0].text, + ("I am a prompt!\n" + "test") + (4 * "testtest"), + ) + + def test_text_environment_compute_rewards(self): + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools={"DummyTool": DummyTool()}, + reward_fn=lambda x: [torch.tensor(i) for i, _ in enumerate(x)], + prompt="I am a prompt!\n", + ) + + histories = [TextHistory("test", torch.tensor([1, 2, 3])) for _ in range(8)] + histories = env.compute_reward(histories) + + for i in range(8): + self.assertEqual(histories[i].reward, i) + + @patch.object(TextEnvironment, "generate", side_effect=dummy_generate) + def test_text_environment_run(self, mock_generate): + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools={"DummyTool": DummyTool()}, + reward_fn=lambda x: [torch.tensor(i) for i, _ in enumerate(x)], + prompt="I am a prompt!\n", + max_turns=2, + ) + task_1 = "Hello there!" + task_2 = "Hello there! General Kenobi!" + + query, response, response_mask, reward, histories = env.run([task_1, task_2]) + self.assertEqual(len(query[0]), 8) + self.assertEqual(len(query[1]), 12) + self.assertEqual(len(response[0]), 14) + self.assertEqual(len(response[1]), 14) + self.assertEqual(response_mask[0].sum(), (2 * 3)) + # mocked generate always adds 3 toknes + self.assertEqual(response_mask[1].sum(), (2 * 3)) + # mocked generate always adds 3 toknes + self.assertEqual(reward[1], 1) + self.assertEqual( + histories[0].text, + ("I am a prompt!\n" + "Hello there!") + (2 * "testtest"), + ) + self.assertEqual( + histories[1].text, + ("I am a prompt!\n" + "Hello there! General Kenobi!") + + (2 * "testtest"), + ) diff --git a/tests/test_gkd_trainer.py b/tests/test_gkd_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..98de9e0fd60b15e62867254f052a8315b388cb4c --- /dev/null +++ b/tests/test_gkd_trainer.py @@ -0,0 +1,264 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import torch +import torch.nn.functional as F +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + +from trl import GKDConfig, GKDTrainer +from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE + + +class TestGKDTrainer(unittest.TestCase): + @classmethod + def setUpClass(cls): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + cls.tokenizer = AutoTokenizer.from_pretrained(model_id) + cls.tokenizer.pad_token = cls.tokenizer.eos_token + cls.model = AutoModelForCausalLM.from_pretrained(model_id) + cls.generation_config = GenerationConfig( + max_new_tokens=20, + num_return_sequences=1, + pad_token_id=cls.tokenizer.pad_token_id, + eos_token_id=cls.tokenizer.eos_token_id, + ) + + def test_generate_on_policy_outputs_deterministic(self): + prompts = ["Hello, how are you?", "What's the weather like today?"] + tokenized_prompts = self.tokenizer(prompts, return_tensors="pt", padding=True) + + inputs = { + "prompts": tokenized_prompts["input_ids"], + "prompt_attention_mask": tokenized_prompts["attention_mask"], + } + + # Set temperature to 0 for deterministic output + deterministic_generation_config = GenerationConfig( + max_new_tokens=30, + num_return_sequences=1, + pad_token_id=self.tokenizer.pad_token_id, + eos_token_id=self.tokenizer.eos_token_id, + temperature=0.0, + ) + + outputs = GKDTrainer.generate_on_policy_outputs( + self.model, inputs, deterministic_generation_config, self.tokenizer.pad_token_id + ) + + new_input_ids, new_attention_mask, new_labels = outputs + + # Decode the generated outputs + generated_texts = self.tokenizer.batch_decode(new_input_ids, skip_special_tokens=True) + + # Check if the generated texts start with the original prompts + for prompt, generated_text in zip(prompts, generated_texts): + self.assertTrue( + generated_text.startswith(prompt), + f"Generated text '{generated_text}' does not start with prompt '{prompt}'", + ) + + # Run the generation twice and check if the outputs are identical + outputs2 = GKDTrainer.generate_on_policy_outputs( + self.model, inputs, deterministic_generation_config, self.tokenizer.pad_token_id + ) + + new_input_ids2, new_attention_mask2, new_labels2 = outputs2 + + # Check if the two generations are identical + self.assertTrue(torch.all(new_input_ids.eq(new_input_ids2)), "Deterministic generations are not identical") + self.assertTrue( + torch.all(new_attention_mask.eq(new_attention_mask2)), + "Attention masks for deterministic generations are not identical", + ) + self.assertTrue( + torch.all(new_labels.eq(new_labels2)), + "Labels for deterministic generations are not identical", + ) + + def test_generate_on_policy_outputs(self): + prompts = ["Hello, how are you?", "What's the weather like today?"] + tokenized_prompts = self.tokenizer(prompts, return_tensors="pt", padding=True) + + inputs = { + "prompts": tokenized_prompts["input_ids"], + "attention_mask": tokenized_prompts["attention_mask"], + } + + outputs = GKDTrainer.generate_on_policy_outputs( + self.model, inputs, self.generation_config, self.tokenizer.pad_token_id + ) + + # Check that outputs is a tuple of three tensors + self.assertIsInstance(outputs, tuple) + self.assertEqual(len(outputs), 3) + + new_input_ids, new_attention_mask, new_labels = outputs + + # Check shapes + batch_size = len(prompts) + self.assertEqual(new_input_ids.shape[0], batch_size) + self.assertEqual(new_attention_mask.shape[0], batch_size) + self.assertEqual(new_labels.shape[0], batch_size) + + # Check types + self.assertIsInstance(new_input_ids, torch.Tensor) + self.assertIsInstance(new_attention_mask, torch.Tensor) + self.assertIsInstance(new_labels, torch.Tensor) + + # Check that new_input_ids and new_attention_mask have the same shape + self.assertEqual(new_input_ids.shape, new_attention_mask.shape) + self.assertEqual(new_labels.shape, new_attention_mask.shape) + + +class TestGeneralizedJSDLoss(unittest.TestCase): + def setUp(self): + self.batch_size = 2 + self.seq_length = 3 + self.vocab_size = 5 + self.student_logits = torch.randn(self.batch_size, self.seq_length, self.vocab_size) + self.teacher_logits = torch.randn(self.batch_size, self.seq_length, self.vocab_size) + + def test_uniform_distribution(self): + logits = torch.ones(1, 1, self.vocab_size) + loss = GKDTrainer.generalized_jsd_loss(logits, logits) + self.assertAlmostEqual(loss.item(), 0, places=5) + + def test_generalized_jsd_loss_edge_cases(self): + # Setup + student_logits = torch.log(torch.tensor([[0.1, 0.9]])).unsqueeze(0) + teacher_logits = torch.log(torch.tensor([[0.9, 0.1]])).unsqueeze(0) + + # Case 1: beta = 1 (should be equivalent to KL(student || teacher)) + loss_beta_1 = GKDTrainer.generalized_jsd_loss(student_logits, teacher_logits, beta=1) + expected_loss_beta_1 = F.kl_div( + F.log_softmax(teacher_logits, dim=-1), F.softmax(student_logits, dim=-1), reduction="batchmean" + ) + self.assertAlmostEqual(loss_beta_1.item(), expected_loss_beta_1.item(), places=5) + + # Case 2: beta = 0 (should be equivalent to KL(teacher || student)) + loss_beta_0 = GKDTrainer.generalized_jsd_loss(student_logits, teacher_logits, beta=0) + expected_loss_beta_0 = F.kl_div( + F.log_softmax(student_logits, dim=-1), F.softmax(teacher_logits, dim=-1), reduction="batchmean" + ) + self.assertAlmostEqual(loss_beta_0.item(), expected_loss_beta_0.item(), places=5) + + def test_output_shape(self): + loss = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits) + self.assertTrue(torch.is_tensor(loss)) + self.assertEqual(loss.shape, torch.Size([])) + + def test_beta_values(self): + loss_beta_0 = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, beta=0) + loss_beta_1 = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, beta=1) + self.assertNotEqual(loss_beta_0, loss_beta_1) + + def test_temperature_scaling(self): + loss_temp_1 = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, temperature=1) + loss_temp_2 = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, temperature=2) + self.assertNotEqual(loss_temp_1, loss_temp_2) + + def test_reduction_methods(self): + loss_batchmean = GKDTrainer.generalized_jsd_loss( + self.student_logits, self.teacher_logits, reduction="batchmean" + ) + loss_sum = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, reduction="sum") + loss_mean = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, reduction="mean") + loss_none = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, reduction="none") + + self.assertEqual(loss_batchmean.shape, torch.Size([])) + self.assertEqual(loss_sum.shape, torch.Size([])) + self.assertEqual(loss_mean.shape, torch.Size([])) + self.assertEqual(loss_none.shape, self.student_logits.shape) + + def test_symmetry(self): + student_teacher = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, beta=0.1) + teacher_student = GKDTrainer.generalized_jsd_loss(self.teacher_logits, self.student_logits, beta=0.1) + self.assertNotEqual(student_teacher, teacher_student) + + student_teacher = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, beta=0.5) + teacher_student = GKDTrainer.generalized_jsd_loss(self.teacher_logits, self.student_logits, beta=0.5) + self.assertEqual(student_teacher, teacher_student) + + def test_zero_loss_for_identical_inputs(self): + identical_logits = torch.randn(self.batch_size, self.seq_length, self.vocab_size) + loss = GKDTrainer.generalized_jsd_loss(identical_logits, identical_logits) + self.assertAlmostEqual(loss.item(), 0, places=6) + + +class GKDTrainerTester(unittest.TestCase): + def setUp(self): + self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.teacher_model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer.pad_token = self.tokenizer.eos_token + + # Ensure the tokenizer has a chat template + if not hasattr(self.tokenizer, "chat_template") or self.tokenizer.chat_template is None: + self.tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE + + def test_gkd_trainer(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GKDConfig( + output_dir=tmp_dir, + dataloader_drop_last=True, + eval_strategy="steps", + max_steps=4, + eval_steps=2, + save_steps=2, + per_device_train_batch_size=2, + per_device_eval_batch_size=2, + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling") + + trainer = GKDTrainer( + model=self.model_id, + teacher_model=self.model_id, + args=training_args, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"]) + self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"]) + self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2")) + + def test_generation_config_init(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GKDConfig(output_dir=tmp_dir) + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling") + + trainer = GKDTrainer( + model=self.model_id, + teacher_model=self.model_id, + args=training_args, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, + ) + + self.assertEqual(trainer.generation_config.pad_token_id, self.tokenizer.eos_token_id) + self.assertEqual(trainer.generation_config.eos_token_id, self.model.generation_config.eos_token_id) + self.assertEqual(trainer.generation_config.max_new_tokens, training_args.max_new_tokens) + self.assertEqual(trainer.generation_config.temperature, training_args.temperature) + self.assertEqual(trainer.generation_config.top_k, 0) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..dbeb571792efefcd2568310f44d2d920866109f0 --- /dev/null +++ b/tests/test_grpo_trainer.py @@ -0,0 +1,1211 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest +from unittest.mock import patch + +import torch +from datasets import load_dataset +from parameterized import parameterized +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer +from transformers.testing_utils import require_peft +from transformers.utils import is_peft_available + +from trl import GRPOConfig, GRPOTrainer +from trl.trainer.grpo_trainer import RepeatSampler, shuffle_tensor_dict, split_tensor_dict + +from .testing_utils import require_vllm + + +if is_peft_available(): + from peft import LoraConfig, PeftModel + + +class SplitTensorDictTester(unittest.TestCase): + def test_split_equal_chunks(self): + x = torch.arange(12).reshape(6, 2) + y = torch.arange(6).reshape(6, 1) + tensor_dict = {"x": x, "y": y} + + result = split_tensor_dict(tensor_dict, 3) + + expected_x_chunks = torch.chunk(x, 3, dim=0) + expected_y_chunks = torch.chunk(y, 3, dim=0) + self.assertEqual(len(result), 3) + for i in range(3): + self.assertTrue(torch.equal(result[i]["x"], expected_x_chunks[i])) + self.assertTrue(torch.equal(result[i]["y"], expected_y_chunks[i])) + + def test_with_none_tensor(self): + x = torch.arange(12).reshape(6, 2) + tensor_dict = {"x": x, "y": None} + + result = split_tensor_dict(tensor_dict, 2) + + expected_x_chunks = torch.chunk(x, 2, dim=0) + self.assertEqual(len(result), 2) + for i in range(2): + self.assertTrue(torch.equal(result[i]["x"], expected_x_chunks[i])) + self.assertIsNone(result[i]["y"]) + + +class ShuffleTensorDictTester(unittest.TestCase): + def test_shuffle_preserves_shape(self): + x = torch.arange(6).reshape(3, 2) + y = torch.arange(3).reshape(3, 1) + tensor_dict = {"x": x.clone(), "y": y.clone()} + + shuffled = shuffle_tensor_dict(tensor_dict) + + self.assertEqual(shuffled["x"].shape, x.shape) + self.assertEqual(shuffled["y"].shape, y.shape) + + def test_shuffle_consistent_across_tensors(self): + # Use known patterns to check alignment + x = torch.tensor([[10, 11], [20, 21], [30, 31]]) + y = torch.tensor([[1], [2], [3]]) + tensor_dict = {"x": x.clone(), "y": y.clone()} + + shuffled = shuffle_tensor_dict(tensor_dict) + + # Build a reverse map from shuffled x rows to y values + for i in range(3): + x_row = shuffled["x"][i] + y_val = shuffled["y"][i].item() + + if torch.equal(x_row, torch.tensor([10, 11])): + self.assertEqual(y_val, 1) + elif torch.equal(x_row, torch.tensor([20, 21])): + self.assertEqual(y_val, 2) + elif torch.equal(x_row, torch.tensor([30, 31])): + self.assertEqual(y_val, 3) + else: + self.fail("Unexpected x row in shuffled output.") + + def test_none_tensor_remains_none(self): + x = torch.arange(6).reshape(3, 2) + tensor_dict = {"x": x.clone(), "y": None} + + shuffled = shuffle_tensor_dict(tensor_dict) + + self.assertIsNone(shuffled["y"]) + self.assertEqual(shuffled["x"].shape, x.shape) + + +class RepeatRandomSamplerTester(unittest.TestCase): + def test_sampler(self): + dataset = ["a", "b", "c", "d", "e", "f", "g"] + sampler = RepeatSampler(dataset, mini_repeat_count=2) + # Should output something like [4, 4, 3, 3, 0, 0, 1, 1, 2, 2, 6, 6, 5, 5] + sampled = list(sampler) + # Check that the length is doubled + assert len(sampled) == 2 * len(dataset) + # Check that all indexes are present + assert set(sampled) == set(range(len(dataset))) + # Check that each element is repeated twice + assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2)) + + def test_sampler_no_shuffle(self): + dataset = ["a", "b", "c", "d", "e", "f", "g"] + sampler = RepeatSampler(dataset, mini_repeat_count=2, shuffle=False) + sampled = list(sampler) + expected = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6] + self.assertEqual(sampled, expected) + + def test_sampler_no_repeat(self): + dataset = ["a", "b", "c", "d", "e", "f", "g"] + sampler = RepeatSampler(dataset, mini_repeat_count=1) + # Should output something like [4, 3, 0, 1, 2, 6, 5] + sampled = list(sampler) + # Check that the length is the same + assert len(sampled) == len(dataset) + # Check that all indexes are present + assert set(sampled) == set(range(len(dataset))) + + def test_sampler_with_batch_size(self): + dataset = ["a", "b", "c", "d", "e", "f", "g", "h"] + sampler = RepeatSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2) + # Should output something like [4, 3, 4, 3, 0, 1, 0, 1, 2, 6, 2, 6, 5, 7, 5, 7] + sampled = list(sampler) + # Check that the length is doubled + assert len(sampled) == 2 * len(dataset) + # Check that all indexes are present + assert set(sampled) == set(range(len(dataset))) + # Check that each element is repeated as expected + assert all(sampled[i : i + 1] == sampled[i + 2 : i + 3] for i in range(0, len(sampled), 4)) + + def test_sampler_with_batch_size_and_drop(self): + dataset = ["a", "b", "c", "d", "e", "f", "g"] + sampler = RepeatSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2) + # Should output something like [4, 3, 4, 3, 0, 1, 0, 1, 2, 6, 2, 6] + sampled = list(sampler) + # Check that the length is doubled + assert len(sampled) == 2 * ( + len(dataset) - 1 + ) # one element is dropped, because it's not enough to form a batch + # Check that the sampled indexes are a subset of the dataset indexes + assert set(sampled).issubset(set(range(len(dataset)))) + # Check that each element is repeated as expected + assert all(sampled[i : i + 1] == sampled[i + 2 : i + 3] for i in range(0, len(sampled), 4)) + + def test_sampler_with_mini_repeat_count_and_batch_size_1(self): + dataset = ["a", "b", "c", "d", "e", "f", "g"] + sampler = RepeatSampler(dataset, mini_repeat_count=2, batch_size=3, repeat_count=2) + # Should output something like [4, 4, 3, 3, 0, 0, 4, 4, 3, 3, 0, 0, + # 1, 1, 2, 2, 6, 6, 1, 1, 2, 2, 6, 6] + sampled = list(sampler) + # Check that the length is quadrupled + assert len(sampled) == 4 * (len(dataset) - 1) # 1 element is dropped, because it's not enough to form a batch + # Check that the sampled indexes are a subset of the dataset indexes + assert set(sampled).issubset(set(range(len(dataset)))) + # Check that each element is repeated as expected + assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2)) + # Check that the batch is repeated as expected + assert sampled[0:6] == sampled[6:12] + assert sampled[12:18] == sampled[18:24] + + def test_sampler_with_mini_repeat_count_and_batch_size_2(self): + dataset = ["a", "b", "c", "d", "e", "f", "g"] + sampler = RepeatSampler(dataset, mini_repeat_count=3, batch_size=2, repeat_count=2) + # Should output something like [4, 4, 4, 3, 3, 3, 4, 4, 4, 3, 3, 3, + # 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, + # 2, 2, 2, 6, 6, 6, 2, 2, 2, 6, 6, 6] + sampled = list(sampler) + # Check that the length is sextupled + assert len(sampled) == 6 * (len(dataset) - 1) # 1 element is dropped, because it's not enough to form a batch + # Check that the sampled indexes are a subset of the dataset indexes + assert set(sampled).issubset(set(range(len(dataset)))) + # Check that each element is repeated as expected + assert all(sampled[i] == sampled[i + 1] == sampled[i + 2] for i in range(0, len(sampled), 3)) + # Check that the batch is repeated as expected + assert sampled[0:6] == sampled[6:12] + assert sampled[12:18] == sampled[18:24] + assert sampled[24:30] == sampled[30:36] + + def test_sampler_with_mini_repeat_count_and_batch_size_3(self): + dataset = ["a", "b", "c", "d", "e", "f", "g"] + sampler = RepeatSampler(dataset, mini_repeat_count=2, batch_size=2, repeat_count=3) + # Should output something like [4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, + # 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, + # 2, 2, 6, 6, 2, 2, 6, 6, 2, 2, 6, 6] + sampled = list(sampler) + # Check that the length is sextupled + assert len(sampled) == 6 * (len(dataset) - 1) # 1 element is dropped, because it's not enough to form a batch + # Check that the sampled indexes are a subset of the dataset indexes + assert set(sampled).issubset(set(range(len(dataset)))) + # Check that each element is repeated as expected + assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2)) + # Check that the batch is repeated as expected + assert sampled[0:4] == sampled[4:8] == sampled[8:12] + assert sampled[12:16] == sampled[16:20] == sampled[20:24] + assert sampled[24:28] == sampled[28:32] == sampled[32:36] + + +class GRPOTrainerTester(unittest.TestCase): + def test_init_minimal(self): + # Test that GRPOTrainer can be instantiated with only model, reward_model and train_dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + train_dataset=dataset, + ) + + @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) + def test_training(self, config_name): + dataset = load_dataset("trl-internal-testing/zen", config_name, split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + @parameterized.expand([("bnpo",), ("dr_grpo",)]) + def test_training_loss_types(self, loss_type): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=32, # reduce the completion length to reduce memory usage + loss_type=loss_type, + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_with_eval(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + per_device_eval_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + eval_strategy="steps", + eval_steps=2, + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset["train"], + eval_dataset=dataset["test"], + ) + + trainer.train() + + def test_training_multiple_iterations(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + num_iterations=2, + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + @require_peft + def test_training_peft(self): + model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model=model, + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + peft_config=LoraConfig(), + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the peft params have changed and the base model params have not changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if n in base_param_names: # We expect the base model params to be the same + self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.") + elif "base_layer" not in n: # We expect the peft params to be different (except for the base layer) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.") + + @require_peft + def test_training_peft_with_gradient_checkpointing(self): + """Test that training works with PEFT and gradient checkpointing enabled.""" + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + model = AutoModelForCausalLM.from_pretrained( + "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + torch_dtype=torch.float32, # Use float32 for testing to avoid precision issues + use_cache=False, # Required for gradient checkpointing + ) + + lora_config = LoraConfig( + r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none" + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=3, + num_generations=3, + max_completion_length=8, + gradient_checkpointing=True, # Enable gradient checkpointing + report_to="none", + ) + trainer = GRPOTrainer( + model=model, + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + peft_config=lora_config, + ) + + # Verify gradient checkpointing is enabled + self.assertIsInstance(trainer.model, PeftModel) + + # Store initial parameters to check which ones change + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that only LoRA parameters have changed, base model parameters remain unchanged + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if "lora" in n.lower(): # LoRA parameters should change + self.assertFalse(torch.equal(param, new_param), f"LoRA parameter {n} has not changed.") + else: # Base model parameters should not change + self.assertTrue(torch.equal(param, new_param), f"Base parameter {n} has changed.") + + def test_training_different_reward_model(self): + # Use a reward model different from the model: different chat template, tokenization, etc. + dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train") + reward_model_id = "trl-internal-testing/tiny-LlamaForSequenceClassification-3.2" + reward_model = AutoModelForSequenceClassification.from_pretrained(reward_model_id) + reward_tokenizer = AutoTokenizer.from_pretrained(reward_model_id) + # By default, the trainer uses the eos token as the padding token. However, for Llama models, the eos token + # appears in the chat template. Using it as a pad token disrupts the reward calculation, as the calculation + # considers the score of the last token before the first pad token. To ensure correct reward calculations, + # we use a separate pad token instead. + reward_tokenizer.pad_token = "<|finetune_right_pad_id|>" + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=reward_model, + args=training_args, + train_dataset=dataset, + reward_processing_classes=reward_tokenizer, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_reward_func_standard(self): + # Test if trainer can handle reward function with standard format + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion)) for completion in completions] + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=reward_func, + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_reward_func_conversational(self): + # Test if trainer can handle reward function with conversational format + dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train") + + def reward_func(completions, **kwargs): + """Reward function that gives higher scores to longer completion content.""" + completion_contents = [completion[0]["content"] for completion in completions] + return [float(len(content)) for content in completion_contents] + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=reward_func, + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_multiple_reward_funcs(self): + # Test that GRPOTrainer can be instantiated with multiple reward functions + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + def reward_func1(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion)) for completion in completions] + + def reward_func2(completions, **kwargs): + """Reward function that rewards completions with more unique letters.""" + return [float(len(set(completion))) for completion in completions] + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=[reward_func1, reward_func2], + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_multiple_reward_funcs_with_None_output(self): + """Test that a valid math reward function is processed correctly while the code reward function returns None.""" + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + def applicable_reward_func(completions, **kwargs): + """A reward function that rewards longer completions.""" + return [float(len(completion)) for completion in completions] + + def non_applicable_reward_func(completions, **kwargs): + """A reward function that returns None for all inputs, as it is not applicable to this sample.""" + return [None] * len(completions) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=3, + num_generations=3, + max_completion_length=8, + report_to="none", + ) + + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=[ + applicable_reward_func, + non_applicable_reward_func, + ], # One applicable, one non applicable + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = { + n: param.clone() for n, param in trainer.model.named_parameters() if param.requires_grad + } + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_multiple_reward_funcs_with_weights(self): + """Test that GRPOTrainer can handle multiple reward functions with weights.""" + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + def reward_func1(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion)) for completion in completions] + + def reward_func2(completions, **kwargs): + """Reward function that rewards completions with more unique letters.""" + return [float(len(set(completion))) for completion in completions] + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + reward_weights=[0.7, 0.3], # weight of reward_func1 and reward_func2 respectively + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=[reward_func1, reward_func2], + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + # Check that training logs contain both reward metrics + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + self.assertIn("rewards/reward_func1/mean", trainer.state.log_history[-1]) + self.assertIn("rewards/reward_func1/std", trainer.state.log_history[-1]) + self.assertIn("rewards/reward_func2/mean", trainer.state.log_history[-1]) + self.assertIn("rewards/reward_func2/std", trainer.state.log_history[-1]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_multiple_mixed_reward_funcs(self): + # Test if the trainer can handle a mix of reward functions and reward models + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion)) for completion in completions] + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=[reward_func, "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"], + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_reward_func_additional_column(self): + # Test if trainer can handle reward function that rely on additional columns in the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + # Add a column to the dataset (dummy example, the column could be anything) + some_values = list(range(len(dataset))) + dataset = dataset.add_column("some_values", some_values) + + def reward_func(completions, some_values, **kwargs): + """Reward function that rewards completions with lengths closer to the values in some_values.""" + return [float(abs(len(completion) - value)) for completion, value in zip(completions, some_values)] + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs=reward_func, + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + @require_vllm + @unittest.skip("We should add a mock for the vLLM server.") + def test_training_vllm(self): + """Test that training works with vLLM for generation.""" + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + use_vllm=True, + ) + trainer = GRPOTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", # tiny is too small for vLLM + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_with_sync_ref_model(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + sync_ref_model=True, + ref_model_sync_steps=2, # reduce sync steps to ensure a sync happens + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_beta_non_zero(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + beta=0.1, # set beta to non-zero value to test the case where the reference model is used + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + @unittest.skip("We should add a mock for the vLLM server.") + @require_peft + @require_vllm + def test_training_vllm_and_peft(self): + """Test that training works with vLLM for generation.""" + model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") # tiny model is too small for vLLM + base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + use_vllm=True, + ) + lora_config = LoraConfig( + target_modules="all-linear", + # test with non-default modules as it add extra keys in state_dict tht we need to handle + modules_to_save=["embed_tokens", "lm_head"], + ) + trainer = GRPOTrainer( + model=model, + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + peft_config=lora_config, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the peft params have changed and the base model params have not changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if n in base_param_names: # We expect the base model params to be the same + self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.") + elif "base_layer" not in n and "original_module" not in n: + # We expect the peft params to be different (except for the base layer) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.") + + @require_vllm + @unittest.skip("We should add a mock for the vLLM server.") + def test_training_vllm_guided_decoding(self): + """Test that training works with vLLM for generation with guided decoding.""" + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + use_vllm=True, + vllm_guided_decoding_regex=r"\n.*\n\n\n.*\n", + ) + trainer = GRPOTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", # tiny model is too small for vLLM + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_with_additional_generation_kwargs(self): + """Test that training works with additional generation kwargs.""" + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + top_p=0.9, + top_k=10, + min_p=0.01, + repetition_penalty=1.1, + ) + + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + @require_vllm + @unittest.skip("We should add a mock for the vLLM server.") + def test_training_vllm_with_additional_generation_kwargs(self): + """Test that training works with vLLM and additional generation kwargs.""" + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + report_to="none", + use_vllm=True, + top_p=0.9, + top_k=10, + min_p=0.01, + repetition_penalty=1.1, + ) + + trainer = GRPOTrainer( + model="Qwen/Qwen2.5-0.5B-Instruct", # tiny model is too small for vLLM + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_no_scale_rewards(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + scale_rewards=False, + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + @patch("transformers.generation.utils.GenerationMixin.generate") + def test_training_with_mask_truncated_completions(self, mock_generate): + """Test that training works with mask_truncated_completions=True parameter.""" + + # We mock the generate method because the model's random weights make it extremely unlikely to produce a + # sequence containing the EOS token within the allowed max_completion_length. As a result, all tokens are + # masked in the loss, the model doesn't update, and the final check (which verifies the update) fails. + def fake_generate(prompt_ids, **kwargs): + # pad_token_id = 151643; eos_token_id = 151645 + completions_ids = torch.tensor( + [ + [1, 2, 3, 4, 5, 6, 7, 8], # this one is truncated + [9, 10, 11, 151645, 151643, 151643, 151643, 151643], # this one contains eos + [12, 13, 14, 15, 16, 17, 18, 151645], # particular case, eos is generated just within the limit + ], + device=prompt_ids.device, + ) + return torch.cat([prompt_ids, completions_ids], dim=1) + + mock_generate.side_effect = fake_generate + + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + mask_truncated_completions=True, # Enable masking of truncated completions + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_with_mask_truncated_completions_all_masked(self): + """ + Test that when all generated completions are truncated (i.e., none contain an EOS token), and + mask_truncated_completions=True, the model receives no effective learning signal and therefore does not update + its parameters. + + Here, we don't mock the generate method, be we rely on the fact that the model the probability of generating + the EOS token is extremely low, so all generated completions are truncated. + """ + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + mask_truncated_completions=True, # Enable masking of truncated completions + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertTrue(torch.equal(param, new_param), f"Parameter {n} has changed.") + + def test_training_num_generations_larger_than_batch_size(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + num_generations=6, # the number of generations is larger than the batch size, but + gradient_accumulation_steps=2, # gradient accumulation should allow that + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_delta_clipping(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + delta=2.0, # set delta to a non-None value + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_training_multiple_dataloader_workers(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = GRPOConfig( + output_dir=tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + dataloader_num_workers=2, # use multiple dataloader workers + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") diff --git a/tests/test_iterative_sft_trainer.py b/tests/test_iterative_sft_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..c79f04cf2d22f79ce9319ff1ead3a6d6d789886e --- /dev/null +++ b/tests/test_iterative_sft_trainer.py @@ -0,0 +1,117 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest +from functools import partial + +import torch +from datasets import Dataset +from parameterized import parameterized +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer, TrainingArguments + +from trl import IterativeSFTTrainer + + +class IterativeTrainerTester(unittest.TestCase): + def setUp(self): + self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer.pad_token = self.tokenizer.eos_token + + # get t5 as seq2seq example: + model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration" + self.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id) + self.t5_tokenizer = AutoTokenizer.from_pretrained(model_id) + + def _init_tensor_dummy_dataset(self): + dummy_dataset_dict = { + "input_ids": [ + torch.tensor([5303, 3621, 3666, 1438, 318]), + torch.tensor([3666, 1438, 318, 3666, 1438, 318]), + torch.tensor([5303, 3621, 3666, 1438, 318]), + ], + "attention_mask": [ + torch.tensor([1, 1, 1, 1, 1]), + torch.tensor([1, 1, 1, 1, 1, 1]), + torch.tensor([1, 1, 1, 1, 1]), + ], + "labels": [ + torch.tensor([5303, 3621, 3666, 1438, 318]), + torch.tensor([3666, 1438, 318, 3666, 1438, 318]), + torch.tensor([5303, 3621, 3666, 1438, 318]), + ], + } + + dummy_dataset = Dataset.from_dict(dummy_dataset_dict) + dummy_dataset.set_format("torch") + return dummy_dataset + + def _init_textual_dummy_dataset(self): + dummy_dataset_dict = { + "texts": ["Testing the IterativeSFTTrainer.", "This is a test of the IterativeSFTTrainer"], + "texts_labels": ["Testing the IterativeSFTTrainer.", "This is a test of the IterativeSFTTrainer"], + } + + dummy_dataset = Dataset.from_dict(dummy_dataset_dict) + dummy_dataset.set_format("torch") + return dummy_dataset + + @parameterized.expand( + [ + ["qwen", "tensor"], + ["qwen", "text"], + ["t5", "tensor"], + ["t5", "text"], + ] + ) + def test_iterative_step_from_tensor(self, model_name, input_name): + with tempfile.TemporaryDirectory() as tmp_dir: + # initialize dataset + if input_name == "tensor": + dummy_dataset = self._init_tensor_dummy_dataset() + inputs = { + "input_ids": dummy_dataset["input_ids"], + "attention_mask": dummy_dataset["attention_mask"], + "labels": dummy_dataset["labels"], + } + else: + dummy_dataset = self._init_textual_dummy_dataset() + inputs = { + "texts": dummy_dataset["texts"], + "texts_labels": dummy_dataset["texts_labels"], + } + + if model_name == "qwen": + model = self.model + tokenizer = self.tokenizer + else: + model = self.t5_model + tokenizer = self.t5_tokenizer + + training_args = TrainingArguments( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=2, + learning_rate=1e-3, + report_to="none", + ) + iterative_trainer = IterativeSFTTrainer(model=model, args=training_args, processing_class=tokenizer) + iterative_trainer.optimizer.zero_grad = partial(iterative_trainer.optimizer.zero_grad, set_to_none=False) + + iterative_trainer.step(**inputs) + + for param in iterative_trainer.model.parameters(): + self.assertIsNotNone(param.grad) diff --git a/tests/test_judges.py b/tests/test_judges.py new file mode 100644 index 0000000000000000000000000000000000000000..4849e22ded572d5a0f45dcddb241a8c6a724765e --- /dev/null +++ b/tests/test_judges.py @@ -0,0 +1,76 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import unittest + +from trl import AllTrueJudge, HfPairwiseJudge, PairRMJudge + +from .testing_utils import RandomBinaryJudge, require_llm_blender + + +class TestJudges(unittest.TestCase): + def _get_prompts_and_pairwise_completions(self): + prompts = ["The capital of France is", "The biggest planet in the solar system is"] + completions = [["Paris", "Marseille"], ["Saturn", "Jupiter"]] + return prompts, completions + + def _get_prompts_and_single_completions(self): + prompts = ["What's the capital of France?", "What's the color of the sky?"] + completions = ["Marseille", "blue"] + return prompts, completions + + def test_all_true_judge(self): + judge = AllTrueJudge(judges=[RandomBinaryJudge(), RandomBinaryJudge()]) + prompts, completions = self._get_prompts_and_single_completions() + judgements = judge.judge(prompts=prompts, completions=completions) + self.assertEqual(len(judgements), 2) + self.assertTrue(all(judgement in {0, 1, -1} for judgement in judgements)) + + @unittest.skip("This test needs to be run manually since it requires a valid Hugging Face API key.") + def test_hugging_face_judge(self): + judge = HfPairwiseJudge() + prompts, completions = self._get_prompts_and_pairwise_completions() + ranks = judge.judge(prompts=prompts, completions=completions) + self.assertEqual(len(ranks), 2) + self.assertTrue(all(isinstance(rank, int) for rank in ranks)) + self.assertEqual(ranks, [0, 1]) + + def load_pair_rm_judge(self): + # When using concurrent tests, PairRM may fail to load the model while another job is still downloading. + # This is a workaround to retry loading the model a few times. + for _ in range(5): + try: + return PairRMJudge() + except ValueError: + time.sleep(5) + raise ValueError("Failed to load PairRMJudge") + + @require_llm_blender + def test_pair_rm_judge(self): + judge = self.load_pair_rm_judge() + prompts, completions = self._get_prompts_and_pairwise_completions() + ranks = judge.judge(prompts=prompts, completions=completions) + self.assertEqual(len(ranks), 2) + self.assertTrue(all(isinstance(rank, int) for rank in ranks)) + self.assertEqual(ranks, [0, 1]) + + @require_llm_blender + def test_pair_rm_judge_return_scores(self): + judge = self.load_pair_rm_judge() + prompts, completions = self._get_prompts_and_pairwise_completions() + probs = judge.judge(prompts=prompts, completions=completions, return_scores=True) + self.assertEqual(len(probs), 2) + self.assertTrue(all(isinstance(prob, float) for prob in probs)) + self.assertTrue(all(0 <= prob <= 1 for prob in probs)) diff --git a/tests/test_kto_trainer.py b/tests/test_kto_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..140467fc87b05aa5b97492e6526b873a7d476626 --- /dev/null +++ b/tests/test_kto_trainer.py @@ -0,0 +1,449 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest + +import torch +from datasets import load_dataset +from parameterized import parameterized +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer +from transformers.testing_utils import require_liger_kernel, require_peft + +from trl import KTOConfig, KTOTrainer +from trl.trainer.kto_trainer import _get_kl_dataset, _process_tokens, _tokenize + +from .testing_utils import require_no_wandb + + +class KTOTrainerTester(unittest.TestCase): + def setUp(self): + self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer.pad_token = self.tokenizer.eos_token + + # get t5 as seq2seq example: + model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration" + self.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id) + self.t5_ref_model = AutoModelForSeq2SeqLM.from_pretrained(model_id) + self.t5_tokenizer = AutoTokenizer.from_pretrained(model_id) + + @parameterized.expand( + [ + ("qwen", "standard_preference", "kto", True, True), + # ("t5", "standard_implicit_prompt_preference", "kto", True, False), # KTO broken for enc-dec + ("qwen", "standard_unpaired_preference", "kto", False, True), + # ("t5", "conversational_preference", "kto", False, False), + ("qwen", "conversational_implicit_prompt_preference", "apo_zero_unpaired", True, True), + # ("t5", "conversational_unpaired_preference", "apo_zero_unpaired", True, False), + ("qwen", "standard_unpaired_preference", "apo_zero_unpaired", False, True), + # ("t5", "conversational_unpaired_preference", "apo_zero_unpaired", False, False), + ] + ) + def test_kto_trainer(self, name, config_name, loss_type, pre_compute, eval_dataset): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = KTOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps" if eval_dataset else "no", + beta=0.1, + precompute_ref_log_probs=pre_compute, + loss_type=loss_type, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + + if name == "qwen": + model = self.model + ref_model = self.ref_model + tokenizer = self.tokenizer + elif name == "t5": + model = self.t5_model + ref_model = self.t5_ref_model + tokenizer = self.t5_tokenizer + + trainer = KTOTrainer( + model=model, + ref_model=ref_model, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"] if eval_dataset else None, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.equal(param, new_param)) + + def test_kto_trainer_with_ref_model_is_model(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = KTOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + + with self.assertRaises(ValueError): + KTOTrainer( + model=self.model, + ref_model=self.model, # ref_model can't be the same as model + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + ) + + def test_tokenize_and_process_tokens(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = KTOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + + trainer = KTOTrainer( + model=self.model, + ref_model=self.ref_model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + train_dataset = dummy_dataset["train"] + tokenized_dataset = train_dataset.map( + _tokenize, + fn_kwargs={"tokenizer": trainer.tokenizer}, + batched=True, + batch_size=2, + ) + self.assertListEqual(tokenized_dataset["prompt"], train_dataset["prompt"]) + self.assertListEqual(tokenized_dataset["completion"], train_dataset["completion"]) + self.assertListEqual(tokenized_dataset["label"], train_dataset["label"]) + self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091]) + self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1, 1, 1, 1]) + self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [27261, 13]) + self.assertListEqual(tokenized_dataset["answer_attention_mask"][0], [1, 1]) + + # Test corruption of (prompt, completion) pairs for KL dataset + for batch_size in [2, 3]: + tokenized_kl_dataset = tokenized_dataset.map(_get_kl_dataset, batched=True, batch_size=batch_size) + + # Verify that the "answer_input_ids" have been modified, meaning the new "answer_input_ids" differ + # from the original ones. However, when the length of the dataset modulo batch_size equals 1, + # the last batch remains unaltered. This is a rare scenario that does not impact the training + # process, so we exclude it from testing by iterating only up to len - 1. + for i in range(len(tokenized_kl_dataset["answer_input_ids"]) - 1): + self.assertListEqual( + tokenized_dataset["prompt_input_ids"][i], + tokenized_kl_dataset["prompt_input_ids"][i], + ) + self.assertListEqual( + tokenized_dataset["prompt_attention_mask"][i], + tokenized_kl_dataset["prompt_attention_mask"][i], + ) + self.assertNotEqual( + tokenized_dataset["answer_input_ids"][i], + tokenized_kl_dataset["answer_input_ids"][i], + ) + + fn_kwargs = { + "prefix": "", + "is_encoder_decoder": trainer.is_encoder_decoder, + "tokenizer": trainer.tokenizer, + "max_length": trainer.max_length, + "truncation_mode": trainer.truncation_mode, + "label_pad_token_id": trainer.label_pad_token_id, + "max_prompt_length": trainer.max_prompt_length, + } + processed_dataset = tokenized_dataset.map(_process_tokens, fn_kwargs=fn_kwargs, num_proc=2) + self.assertListEqual(processed_dataset["prompt"], train_dataset["prompt"]) + self.assertListEqual(processed_dataset["completion"], train_dataset["completion"]) + self.assertListEqual(processed_dataset["label"], train_dataset["label"]) + self.assertListEqual(processed_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091]) + self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1, 1, 1, 1]) + self.assertListEqual( + processed_dataset["completion_input_ids"][0], [46518, 374, 2664, 1091, 27261, 13, 151645] + ) + self.assertListEqual(processed_dataset["completion_attention_mask"][0], [1, 1, 1, 1, 1, 1, 1]) + self.assertListEqual( + processed_dataset["completion_labels"][0], [-100, -100, -100, -100, 27261, 13, 151645] + ) + + def test_kto_trainer_without_providing_ref_model(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = KTOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + + trainer = KTOTrainer( + model=self.model, + ref_model=None, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.equal(param, new_param)) + + @require_peft + def test_kto_trainer_without_providing_ref_model_with_lora(self): + from peft import LoraConfig + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = KTOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + + trainer = KTOTrainer( + model=self.model, + ref_model=None, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + if "lora" in n: + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.equal(param, new_param)) + + @require_no_wandb + def test_kto_trainer_generate_during_eval_no_wandb(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = KTOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + generate_during_eval=True, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + + with self.assertRaisesRegex( + ValueError, + expected_regex="`generate_during_eval=True` requires Weights and Biases or Comet to be installed." + " Please install `wandb` or `comet-ml` to resolve.", + ): + KTOTrainer( + model=self.model, + ref_model=None, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + @require_peft + def test_kto_lora_save(self): + from peft import LoraConfig, get_peft_model + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + # lora model + model = AutoModelForCausalLM.from_pretrained(self.model_id) + model_peft = get_peft_model(model, lora_config) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = KTOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + + # kto train lora model with a lora config + trainer = KTOTrainer( + model=model_peft, + ref_model=None, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + # train the model + trainer.train() + + # save peft adapter + trainer.save_model() + + # assert that the model is loaded without giving OSError + try: + AutoModelForCausalLM.from_pretrained(tmp_dir) + except OSError: + self.fail("Loading the saved peft adapter failed") + + @require_liger_kernel + def test_kto_trainer_with_liger(self): + """Test KTO trainer with Liger loss enabled.""" + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = KTOConfig( + output_dir=tmp_dir, + report_to="none", + use_liger_loss=True, # Enable Liger loss + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + + trainer = KTOTrainer( + model=self.model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # check the params have changed - ignore 0 biases + if param.sum() != 0: + self.assertFalse(torch.equal(param, new_param)) + + def test_compute_metrics(self): + model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + tokenizer.pad_token = tokenizer.eos_token + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + + def dummy_compute_metrics(*args, **kwargs): + return {"test": 0.0} + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = KTOConfig( + output_dir=tmp_dir, + remove_unused_columns=False, + per_device_train_batch_size=2, + do_eval=True, + eval_strategy="steps", + eval_steps=1, + per_device_eval_batch_size=2, + report_to="none", + ) + + trainer = KTOTrainer( + model=model, + ref_model=ref_model, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + compute_metrics=dummy_compute_metrics, + ) + + trainer.train() + + self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0) diff --git a/tests/test_modeling_geometric_mixture_wrapper.py b/tests/test_modeling_geometric_mixture_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..abfcdd6349407ecf96afceded88d03a8737f9250 --- /dev/null +++ b/tests/test_modeling_geometric_mixture_wrapper.py @@ -0,0 +1,67 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from transformers import AutoModelForCausalLM, GenerationConfig + +from trl.models.modeling_base import GeometricMixtureWrapper, create_reference_model + + +class TestGeometricMixtureWrapper(unittest.TestCase): + def setUp(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.model = AutoModelForCausalLM.from_pretrained(model_id) + self.ref_model = create_reference_model(self.model) + self.generation_config = GenerationConfig.from_pretrained(model_id) + self.mixture_coef = 0.5 + self.wrapper = GeometricMixtureWrapper( + self.model, self.ref_model, self.generation_config, mixture_coef=self.mixture_coef + ) + + def test_forward(self): + input_ids = torch.tensor([[1, 2, 3, 4, 5]]) + attention_mask = torch.ones_like(input_ids) + + output = self.wrapper(input_ids=input_ids, attention_mask=attention_mask) + + self.assertIsNotNone(output) + self.assertTrue(hasattr(output, "logits")) + self.assertEqual(output.logits.shape, (1, 5, self.model.config.vocab_size)) + + def test_mixture_coefficient(self): + input_ids = torch.tensor([[1, 2, 3, 4, 5]]) + attention_mask = torch.ones_like(input_ids) + + with torch.no_grad(): + model_output = self.model(input_ids=input_ids, attention_mask=attention_mask) + ref_model_output = self.ref_model(input_ids=input_ids, attention_mask=attention_mask) + wrapper_output = self.wrapper(input_ids=input_ids, attention_mask=attention_mask) + + expected_logits = torch.nn.functional.log_softmax( + self.mixture_coef * ref_model_output.logits + (1 - self.mixture_coef) * model_output.logits, dim=-1 + ) + + self.assertTrue(torch.allclose(wrapper_output.logits, expected_logits, atol=1e-5)) + + def test_prepare_inputs_for_generation(self): + input_ids = torch.tensor([[1, 2, 3, 4, 5]]) + attention_mask = torch.ones_like(input_ids) + + inputs = self.wrapper.prepare_inputs_for_generation(input_ids, attention_mask=attention_mask, use_cache=True) + + self.assertIn("input_ids", inputs) + self.assertIn("attention_mask", inputs) + self.assertFalse(inputs.get("use_cache", False)) diff --git a/tests/test_modeling_value_head.py b/tests/test_modeling_value_head.py new file mode 100644 index 0000000000000000000000000000000000000000..539cdf2fbfe63ca01b4a0a690dd2611efaabfe76 --- /dev/null +++ b/tests/test_modeling_value_head.py @@ -0,0 +1,504 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import tempfile +import unittest + +import torch +from parameterized import parameterized +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, GenerationConfig + +from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, create_reference_model + + +ALL_CAUSAL_LM_MODELS = [ + "trl-internal-testing/tiny-BloomForCausalLM", + "trl-internal-testing/tiny-CohereForCausalLM", + "trl-internal-testing/tiny-DbrxForCausalLM", + "trl-internal-testing/tiny-FalconMambaForCausalLM", + "trl-internal-testing/tiny-Gemma2ForCausalLM", + "trl-internal-testing/tiny-GemmaForCausalLM", + "trl-internal-testing/tiny-GPT2LMHeadModel", + "trl-internal-testing/tiny-GPTNeoXForCausalLM", + "trl-internal-testing/tiny-LlamaForCausalLM-3.1", + "trl-internal-testing/tiny-LlamaForCausalLM-3.2", + "trl-internal-testing/tiny-LlamaForCausalLM-3", + "trl-internal-testing/tiny-MistralForCausalLM-0.1", + "trl-internal-testing/tiny-MistralForCausalLM-0.2", + "trl-internal-testing/tiny-OPTForCausalLM", + "trl-internal-testing/tiny-Phi3ForCausalLM", + "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", +] + +ALL_SEQ2SEQ_MODELS = [ + "trl-internal-testing/tiny-T5ForConditionalGeneration", + "trl-internal-testing/tiny-BartModel", +] + + +class BaseTester: + class VHeadModelTester(unittest.TestCase): + all_model_names = None + trl_model_class = None + transformers_model_class = None + + def test_value_head(self): + r""" + Test if the v-head is added to the model successfully + """ + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + self.assertTrue(hasattr(model, "v_head")) + + def test_value_head_shape(self): + r""" + Test if the v-head has the correct shape + """ + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + self.assertEqual(model.v_head.summary.weight.shape[0], 1) + + def test_value_head_init_random(self): + r""" + Test if the v-head has been randomly initialized. + We can check that by making sure the bias is different + than zeros by default. + """ + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + self.assertFalse( + torch.allclose(model.v_head.summary.bias, torch.zeros_like(model.v_head.summary.bias)) + ) + + def test_value_head_not_str(self): + r""" + Test if the v-head is added to the model successfully, by passing a non `PretrainedModel` + as an argument to `from_pretrained`. + """ + for model_name in self.all_model_names: + pretrained_model = self.transformers_model_class.from_pretrained(model_name) + model = self.trl_model_class.from_pretrained(pretrained_model) + self.assertTrue(hasattr(model, "v_head")) + + def test_from_save_trl(self): + """ + Test if the model can be saved and loaded from a directory and get the same weights + Including the additional modules (e.g. v_head) + """ + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir) + + model_from_save = self.trl_model_class.from_pretrained(tmp_dir) + + # Check if the weights are the same + for key in model_from_save.state_dict(): + self.assertTrue(torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key])) + + def test_from_save_trl_sharded(self): + """ + Test if the model can be saved and loaded from a directory and get the same weights - sharded case + """ + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir) + + model_from_save = self.trl_model_class.from_pretrained(tmp_dir) + + # Check if the weights are the same + for key in model_from_save.state_dict(): + self.assertTrue(torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key])) + + def test_from_save_transformers_sharded(self): + """ + Test if the model can be saved and loaded using transformers and get the same weights - sharded case + """ + for model_name in self.all_model_names: + transformers_model = self.trl_model_class.transformers_parent_class.from_pretrained(model_name) + + trl_model = self.trl_model_class.from_pretrained(model_name) + + with tempfile.TemporaryDirectory() as tmp_dir: + trl_model.save_pretrained(tmp_dir, max_shard_size="1MB") + transformers_model_from_save = self.trl_model_class.transformers_parent_class.from_pretrained( + tmp_dir + ) + + # Check if the weights are the same + for key in transformers_model.state_dict(): + self.assertTrue( + torch.allclose( + transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key] + ) + ) + + def test_from_save_transformers(self): + """ + Test if the model can be saved and loaded using transformers and get the same weights. + We override the test of the super class to check if the weights are the same. + """ + for model_name in self.all_model_names: + transformers_model = self.trl_model_class.transformers_parent_class.from_pretrained(model_name) + + trl_model = self.trl_model_class.from_pretrained(model_name) + + with tempfile.TemporaryDirectory() as tmp_dir: + trl_model.save_pretrained(tmp_dir) + transformers_model_from_save = self.trl_model_class.transformers_parent_class.from_pretrained( + tmp_dir + ) + + # Check if the weights are the same + for key in transformers_model.state_dict(): + self.assertTrue( + torch.allclose( + transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key] + ) + ) + + # Check if the trl model has the same keys as the transformers model + # except the v_head + for key in trl_model.state_dict(): + if "v_head" not in key: + self.assertIn(key, transformers_model.state_dict()) + # check if the weights are the same + self.assertTrue( + torch.allclose(trl_model.state_dict()[key], transformers_model.state_dict()[key]) + ) + + # check if they have the same modules + self.assertEqual( + set(transformers_model_from_save.state_dict().keys()), + set(transformers_model.state_dict().keys()), + ) + + +class CausalLMValueHeadModelTester(BaseTester.VHeadModelTester, unittest.TestCase): + """ + Testing suite for v-head models. + """ + + all_model_names = ALL_CAUSAL_LM_MODELS + trl_model_class = AutoModelForCausalLMWithValueHead + transformers_model_class = AutoModelForCausalLM + + def tearDown(self): + # free memory + gc.collect() + + def test_inference(self): + r""" + Test if the model can be used for inference and outputs 3 values + - logits, loss, and value states + """ + EXPECTED_OUTPUT_SIZE = 3 + + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + outputs = model(input_ids) + + # Check if the outputs are of the right size - here + # we always output 3 values - logits, loss, and value states + self.assertEqual(len(outputs), EXPECTED_OUTPUT_SIZE) + + def test_dropout_config(self): + r""" + Test if we instantiate a model by adding `summary_drop_prob` to the config + it will be added to the v_head + """ + for model_name in self.all_model_names: + pretrained_model = self.transformers_model_class.from_pretrained(model_name) + pretrained_model.config.summary_dropout_prob = 0.5 + model = self.trl_model_class.from_pretrained(pretrained_model) + + # Check if v head of the model has the same dropout as the config + self.assertEqual(model.v_head.dropout.p, pretrained_model.config.summary_dropout_prob) + + def test_dropout_kwargs(self): + r""" + Test if we instantiate a model by adding `summary_drop_prob` to the config + it will be added to the v_head + """ + for model_name in self.all_model_names: + v_head_kwargs = {"summary_dropout_prob": 0.5} + + model = self.trl_model_class.from_pretrained(model_name, **v_head_kwargs) + + # Check if v head of the model has the same dropout as the config + self.assertEqual(model.v_head.dropout.p, 0.5) + + model = self.trl_model_class.from_pretrained(model_name, summary_dropout_prob=0.5) + + # Check if v head of the model has the same dropout as the config + self.assertEqual(model.v_head.dropout.p, 0.5) + + @parameterized.expand(ALL_CAUSAL_LM_MODELS) + def test_generate(self, model_name): + r""" + Test if `generate` works for every model + """ + generation_config = GenerationConfig(max_new_tokens=9) + model = self.trl_model_class.from_pretrained(model_name) + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + + # Just check if the generation works + _ = model.generate(input_ids, generation_config=generation_config) + + def test_transformers_bf16_kwargs(self): + r""" + Test if the transformers kwargs are correctly passed + Here we check that loading a model in half precision works as expected, i.e. the weights of + the `pretrained_model` attribute is loaded in half precision and you can run a dummy + forward pass without any issue. + """ + for model_name in self.all_model_names: + trl_model = self.trl_model_class.from_pretrained(model_name, torch_dtype=torch.bfloat16) + + lm_head_namings = ["lm_head", "embed_out", "output_layer"] + + self.assertTrue( + any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings), + "Can't test the model because it doesn't have any of the expected lm_head namings", + ) + + for lm_head_naming in lm_head_namings: + if hasattr(trl_model.pretrained_model, lm_head_naming): + self.assertEqual(getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype, torch.bfloat16) + + dummy_input = torch.LongTensor([[0, 1, 0, 1]]) + + # check dummy forward pass works in half precision + _ = trl_model(dummy_input) + + @unittest.skip("This test needs to be run manually due to HF token issue.") + def test_push_to_hub(self): + for model_name in self.all_model_names: + model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name) + if "sharded" in model_name: + model.push_to_hub(model_name + "-ppo", use_auth_token=True, max_shard_size="1MB") + else: + model.push_to_hub(model_name + "-ppo", use_auth_token=True) + + model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(model_name + "-ppo") + # check all keys + self.assertEqual(model.state_dict().keys(), model_from_pretrained.state_dict().keys()) + + for name, param in model.state_dict().items(): + self.assertTrue( + torch.allclose(param, model_from_pretrained.state_dict()[name]), + f"Parameter {name} is not the same after push_to_hub and from_pretrained", + ) + + +class Seq2SeqValueHeadModelTester(BaseTester.VHeadModelTester, unittest.TestCase): + """ + Testing suite for v-head models. + """ + + all_model_names = ALL_SEQ2SEQ_MODELS + trl_model_class = AutoModelForSeq2SeqLMWithValueHead + transformers_model_class = AutoModelForSeq2SeqLM + + def tearDown(self): + # free memory + gc.collect() + + def test_inference(self): + r""" + Test if the model can be used for inference and outputs 3 values + - logits, loss, and value states + """ + EXPECTED_OUTPUT_SIZE = 3 + + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + decoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + outputs = model(input_ids, decoder_input_ids=decoder_input_ids) + + # Check if the outputs are of the right size - here + # we always output 3 values - logits, loss, and value states + self.assertEqual(len(outputs), EXPECTED_OUTPUT_SIZE) + + def test_dropout_config(self): + r""" + Test if we instantiate a model by adding `summary_drop_prob` to the config + it will be added to the v_head + """ + for model_name in self.all_model_names: + pretrained_model = self.transformers_model_class.from_pretrained(model_name) + pretrained_model.config.summary_dropout_prob = 0.5 + model = self.trl_model_class.from_pretrained(pretrained_model) + + # Check if v head of the model has the same dropout as the config + self.assertEqual(model.v_head.dropout.p, pretrained_model.config.summary_dropout_prob) + + def test_dropout_kwargs(self): + r""" + Test if we instantiate a model by adding `summary_drop_prob` to the config + it will be added to the v_head + """ + for model_name in self.all_model_names: + v_head_kwargs = {"summary_dropout_prob": 0.5} + + model = self.trl_model_class.from_pretrained(model_name, **v_head_kwargs) + + # Check if v head of the model has the same dropout as the config + self.assertEqual(model.v_head.dropout.p, 0.5) + + model = self.trl_model_class.from_pretrained(model_name, summary_dropout_prob=0.5) + + # Check if v head of the model has the same dropout as the config + self.assertEqual(model.v_head.dropout.p, 0.5) + + @parameterized.expand(ALL_SEQ2SEQ_MODELS) + def test_generate(self, model_name): + r""" + Test if `generate` works for every model + """ + generation_config = GenerationConfig(max_new_tokens=9) + model = self.trl_model_class.from_pretrained(model_name) + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + decoder_input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + + # Just check if the generation works + _ = model.generate(input_ids, decoder_input_ids=decoder_input_ids, generation_config=generation_config) + + @unittest.skip("This test needs to be run manually due to HF token issue.") + def test_push_to_hub(self): + for model_name in self.all_model_names: + model = self.trl_model_class.from_pretrained(model_name) + if "sharded" in model_name: + model.push_to_hub(model_name + "-ppo", use_auth_token=True, max_shard_size="1MB") + else: + model.push_to_hub(model_name + "-ppo", use_auth_token=True) + + model_from_pretrained = self.trl_model_class.from_pretrained(model_name + "-ppo") + # check all keys + self.assertEqual(model.state_dict().keys(), model_from_pretrained.state_dict().keys()) + + for name, param in model.state_dict().items(): + self.assertTrue( + torch.allclose(param, model_from_pretrained.state_dict()[name]), + f"Parameter {name} is not the same after push_to_hub and from_pretrained", + ) + + def test_transformers_bf16_kwargs(self): + r""" + Test if the transformers kwargs are correctly passed + Here we check that loading a model in half precision works as expected, i.e. the weights of + the `pretrained_model` attribute is loaded in half precision and you can run a dummy + forward pass without any issue. + """ + for model_name in self.all_model_names: + trl_model = self.trl_model_class.from_pretrained(model_name, torch_dtype=torch.bfloat16) + + lm_head_namings = self.trl_model_class.lm_head_namings + + self.assertTrue( + any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings) + ) + + for lm_head_naming in lm_head_namings: + if hasattr(trl_model.pretrained_model, lm_head_naming): + self.assertTrue(getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16) + + dummy_input = torch.LongTensor([[0, 1, 0, 1]]) + + # check dummy forward pass works in half precision + _ = trl_model(input_ids=dummy_input, decoder_input_ids=dummy_input) + + +class ReferenceModelTest(unittest.TestCase): + def setUp(self): + self.model = AutoModelForCausalLMWithValueHead.from_pretrained("trl-internal-testing/tiny-GPT2LMHeadModel") + self.test_input = torch.tensor([[0, 1, 2, 3]]) + self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1) + self.layer_format = "pretrained_model.transformer.h.{layer}.attn.c_attn.weight" + + def test_independent_reference(self): + layer_0 = self.layer_format.format(layer=0) + layer_1 = self.layer_format.format(layer=1) + + ref_model = create_reference_model(self.model) + + first_layer_before = self.model.get_parameter(layer_0).data.clone() + last_layer_before = self.model.get_parameter(layer_1).data.clone() # the model only has 2 layers + + first_ref_layer_before = ref_model.get_parameter(layer_0).data.clone() + last_ref_layer_before = ref_model.get_parameter(layer_1).data.clone() + + output = self.model(input_ids=self.test_input, labels=self.test_input) + output[1].backward() + self.optimizer.step() + + first_layer_after = self.model.get_parameter(layer_0).data.clone() + last_layer_after = self.model.get_parameter(layer_1).data.clone() + + first_ref_layer_after = ref_model.get_parameter(layer_0).data.clone() + last_ref_layer_after = ref_model.get_parameter(layer_1).data.clone() + + # before optimization ref and model are identical + self.assertTrue((first_layer_before == first_ref_layer_before).all()) + self.assertTrue((last_layer_before == last_ref_layer_before).all()) + + # ref model stays identical after optimization + self.assertTrue((first_ref_layer_before == first_ref_layer_after).all()) + self.assertTrue((last_ref_layer_before == last_ref_layer_after).all()) + + # optimized model changes + self.assertFalse((first_layer_before == first_layer_after).all()) + self.assertFalse((last_layer_before == last_layer_after).all()) + + def test_shared_layers(self): + layer_0 = self.layer_format.format(layer=0) + layer_1 = self.layer_format.format(layer=1) + + ref_model = create_reference_model(self.model, num_shared_layers=1) + + first_layer_before = self.model.get_parameter(layer_0).data.clone() + second_layer_before = self.model.get_parameter(layer_1).data.clone() + + first_ref_layer_before = ref_model.get_parameter(layer_0).data.clone() + second_ref_layer_before = ref_model.get_parameter(layer_1).data.clone() + + output = self.model(input_ids=self.test_input, labels=self.test_input) + output[1].backward() + self.optimizer.step() + + first_layer_after = self.model.get_parameter(layer_0).data.clone() + second_layer_after = self.model.get_parameter(layer_1).data.clone() + + first_ref_layer_after = ref_model.get_parameter(layer_0).data.clone() + second_ref_layer_after = ref_model.get_parameter(layer_1).data.clone() + + # before optimization ref and model are identical + self.assertTrue((first_layer_before == first_ref_layer_before).all()) + self.assertTrue((second_layer_before == second_ref_layer_before).all()) + + # ref model stays identical after optimization + self.assertTrue((first_ref_layer_before == first_ref_layer_after).all()) + self.assertTrue((second_ref_layer_before == second_ref_layer_after).all()) + + # first layer of optimized model stays the same + self.assertTrue((first_layer_before == first_layer_after).all()) + + # other layers in optimized model change + self.assertFalse((second_layer_before == second_layer_after).all()) diff --git a/tests/test_nash_md_trainer.py b/tests/test_nash_md_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e6a346f315a960c91a672ae81127bfe290c2d74e --- /dev/null +++ b/tests/test_nash_md_trainer.py @@ -0,0 +1,225 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest + +from datasets import load_dataset +from parameterized import parameterized +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer +from transformers.testing_utils import require_peft +from transformers.utils import is_peft_available + +from trl import NashMDConfig, NashMDTrainer + +from .testing_utils import RandomPairwiseJudge, require_llm_blender + + +if is_peft_available(): + from peft import LoraConfig, get_peft_model + + +class TestNashMDTrainer(unittest.TestCase): + def setUp(self): + self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.reward_model = AutoModelForSequenceClassification.from_pretrained(self.model_id, num_labels=1) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer.pad_token = self.tokenizer.eos_token + + @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) + def test_nash_md_trainer_training(self, config_name): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = NashMDConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + + trainer = NashMDTrainer( + model=self.model, + ref_model=self.ref_model, + reward_model=self.reward_model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + trainer.train() + + # Check if training loss is available + self.assertIn("train_loss", trainer.state.log_history[-1]) + + @require_peft + def test_training_with_peft(self): + lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = NashMDConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = NashMDTrainer( + model=self.model, + reward_model=self.reward_model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + trainer.train() + + # Check if training loss is available + self.assertIn("train_loss", trainer.state.log_history[-1]) + + @require_peft + def test_training_with_peft_and_ref_model(self): + lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = NashMDConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = NashMDTrainer( + model=self.model, + ref_model=self.ref_model, + reward_model=self.reward_model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + trainer.train() + + # Check if training loss is available + self.assertIn("train_loss", trainer.state.log_history[-1]) + + @require_peft + def test_training_with_peft_model_and_peft_config(self): + model_lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM") + model = get_peft_model(self.model, model_lora_config) + # we want only the "train adapter" to be trained + lora_train_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = NashMDConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = NashMDTrainer( + model=model, + reward_model=self.reward_model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_train_config, + ) + + trainer.train() + + # Check if training loss is available + self.assertIn("train_loss", trainer.state.log_history[-1]) + + @require_peft + def test_training_pre_pefted_model_implicit_ref_with_reward_model(self): + lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM") + # self.model from setUp is a base AutoModelForCausalLM + peft_model_instance = get_peft_model(self.model, lora_config) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = NashMDConfig( + output_dir=tmp_dir, + per_device_train_batch_size=1, # Keep small for quick test + max_steps=2, # Few steps + learning_rate=5.0e-7, + eval_strategy="no", + report_to="none", + remove_unused_columns=False, # Important for the dummy dataset + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")["train"] + + trainer = NashMDTrainer( + model=peft_model_instance, # Pass the already PEFT model + ref_model=None, # Implicit reference from peft_model_instance's base + reward_model=self.reward_model, # To trigger GeometricMixtureWrapper path + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset, + # peft_config is not passed, as model is already PEFT + ) + + trainer.train() + + self.assertIn("train_loss", trainer.state.log_history[-1]) + + @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) + @require_llm_blender + def test_nash_md_trainer_judge_training(self, config_name): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = NashMDConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + judge = RandomPairwiseJudge() + + trainer = NashMDTrainer( + model=self.model, + ref_model=self.ref_model, + judge=judge, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + trainer.train() + + # Check if training loss is available + self.assertIn("train_loss", trainer.state.log_history[-1]) diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..0d3a64326bbb55da725c89f8b7017c7faddfb1e4 --- /dev/null +++ b/tests/test_online_dpo_trainer.py @@ -0,0 +1,274 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest + +import pytest +from datasets import load_dataset +from parameterized import parameterized +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer +from transformers.testing_utils import require_peft, require_torch_accelerator +from transformers.utils import is_peft_available + +from trl import OnlineDPOConfig, OnlineDPOTrainer + +from .testing_utils import RandomPairwiseJudge, require_llm_blender, require_vllm + + +if is_peft_available(): + from peft import LoraConfig, get_peft_model + + +class TestOnlineDPOTrainer(unittest.TestCase): + def setUp(self): + self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.reward_model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2" + self.reward_model = AutoModelForSequenceClassification.from_pretrained(self.reward_model_id, num_labels=1) + self.reward_tokenizer = AutoTokenizer.from_pretrained(self.reward_model_id) + self.reward_tokenizer.pad_token = self.reward_tokenizer.eos_token + + @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) + def test_training(self, config_name): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = OnlineDPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + + trainer = OnlineDPOTrainer( + model=self.model, + reward_model=self.reward_model, + args=training_args, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, + reward_processing_class=self.reward_tokenizer, + ) + trainer.train() + + # Check if training loss is available + self.assertIn("train_loss", trainer.state.log_history[-1]) + + def test_training_with_ref_model(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = OnlineDPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = OnlineDPOTrainer( + model=self.model, + ref_model=self.ref_model, + reward_model=self.reward_model, + args=training_args, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, + reward_processing_class=self.reward_tokenizer, + ) + trainer.train() + + # Check if training loss is available + self.assertIn("train_loss", trainer.state.log_history[-1]) + + def test_ref_model_is_model(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = OnlineDPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + with self.assertRaises(ValueError): + OnlineDPOTrainer( + model=self.model, + ref_model=self.model, # ref_model can't be the same as model + reward_model=self.reward_model, + args=training_args, + train_dataset=dummy_dataset["train"], + processing_class=self.tokenizer, + reward_processing_class=self.reward_tokenizer, + ) + + @require_peft + def test_training_with_peft(self): + lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = OnlineDPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = OnlineDPOTrainer( + model=self.model, + reward_model=self.reward_model, + args=training_args, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, + reward_processing_class=self.reward_tokenizer, + peft_config=lora_config, + ) + + trainer.train() + + # Check if training loss is available + self.assertIn("train_loss", trainer.state.log_history[-1]) + + @require_peft + def test_training_with_peft_and_ref_model(self): + lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = OnlineDPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = OnlineDPOTrainer( + model=self.model, + ref_model=self.ref_model, + reward_model=self.reward_model, + args=training_args, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, + reward_processing_class=self.reward_tokenizer, + peft_config=lora_config, + ) + + trainer.train() + + # Check if training loss is available + self.assertIn("train_loss", trainer.state.log_history[-1]) + + @require_peft + def test_training_with_peft_model_and_peft_config(self): + model_lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM") + model = get_peft_model(self.model, model_lora_config) + # we want only the "train adapter" to be trained + lora_train_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = OnlineDPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = OnlineDPOTrainer( + model=model, + reward_model=self.reward_model, + args=training_args, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, + reward_processing_class=self.reward_tokenizer, + peft_config=lora_train_config, + ) + + trainer.train() + + # Check if training loss is available + self.assertIn("train_loss", trainer.state.log_history[-1]) + + @require_llm_blender + @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) + def test_training_with_judge(self, config_name): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = OnlineDPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + + trainer = OnlineDPOTrainer( + model=self.model, + judge=RandomPairwiseJudge(), + args=training_args, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, + ) + trainer.train() + + # Check if training loss is available + self.assertIn("train_loss", trainer.state.log_history[-1]) + + @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) + @require_torch_accelerator + @require_vllm + @pytest.mark.slow + def test_training_with_vllm(self, config_name): + model_id = "trl-internal-testing/small-Qwen2ForCausalLM-2.5" # We need a bigger model + model = AutoModelForCausalLM.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = OnlineDPOConfig( + output_dir=tmp_dir, + use_vllm=True, + gpu_memory_utilization=0.2, + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + + trainer = OnlineDPOTrainer( + model=model, + reward_model=self.reward_model, + args=training_args, + train_dataset=dummy_dataset["train"], + processing_class=tokenizer, + reward_processing_class=self.reward_tokenizer, + ) + trainer.train() + + # Check if training loss is available + self.assertIn("train_loss", trainer.state.log_history[-1]) diff --git a/tests/test_orpo_trainer.py b/tests/test_orpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..754974b4b7e89e2f2278651d8c8eca766559531b --- /dev/null +++ b/tests/test_orpo_trainer.py @@ -0,0 +1,183 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest + +import torch +from datasets import load_dataset +from parameterized import parameterized +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer +from transformers.testing_utils import require_peft + +from trl import ORPOConfig, ORPOTrainer +from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE + + +class ORPOTrainerTester(unittest.TestCase): + def setUp(self): + self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer.pad_token = self.tokenizer.eos_token + + # get t5 as seq2seq example: + model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration" + self.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id) + self.t5_tokenizer = AutoTokenizer.from_pretrained(model_id) + self.t5_tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE + + @parameterized.expand( + [ + ("qwen", "standard_preference"), + ("t5", "standard_implicit_prompt_preference"), + ("qwen", "conversational_preference"), + ("t5", "conversational_implicit_prompt_preference"), + ] + ) + def test_orpo_trainer(self, name, config_name): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = ORPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + + if name == "qwen": + model = self.model + tokenizer = self.tokenizer + elif name == "t5": + model = self.t5_model + tokenizer = self.t5_tokenizer + training_args.is_encoder_decoder = True + + trainer = ORPOTrainer( + model=model, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.equal(param, new_param)) + + @parameterized.expand( + [ + ("standard_preference",), + ("standard_implicit_prompt_preference",), + ("conversational_preference",), + ("conversational_implicit_prompt_preference",), + ] + ) + @require_peft + def test_orpo_trainer_with_lora(self, config_name): + from peft import LoraConfig + + lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = ORPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=4, + learning_rate=9e-1, + eval_strategy="steps", + beta=0.1, + report_to="none", + ) + + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + + trainer = ORPOTrainer( + model=self.model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + if "lora" in n: + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.equal(param, new_param)) + + def test_compute_metrics(self): + model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + tokenizer.pad_token = tokenizer.eos_token + + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + + def dummy_compute_metrics(*args, **kwargs): + return {"test": 0.0} + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = ORPOConfig( + output_dir=tmp_dir, + remove_unused_columns=False, + per_device_train_batch_size=2, + do_eval=True, + eval_strategy="steps", + eval_steps=1, + per_device_eval_batch_size=2, + report_to="none", + ) + + trainer = ORPOTrainer( + model=model, + args=training_args, + processing_class=tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + compute_metrics=dummy_compute_metrics, + ) + + trainer.train() + + self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0) diff --git a/tests/test_peft_models.py b/tests/test_peft_models.py new file mode 100644 index 0000000000000000000000000000000000000000..230a757926e8a314b1696d5499d721e9eeb4ff62 --- /dev/null +++ b/tests/test_peft_models.py @@ -0,0 +1,205 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import torch +from transformers import AutoModelForCausalLM +from transformers.testing_utils import ( + require_peft, + require_torch_gpu_if_bnb_not_multi_backend_enabled, +) +from transformers.utils import is_peft_available + +from trl import AutoModelForCausalLMWithValueHead + + +if is_peft_available(): + from peft import LoraConfig, get_peft_model + + +@require_peft +class PeftModelTester(unittest.TestCase): + def setUp(self): + self.causal_lm_model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + def test_create_peft_model(self): + r""" + Simply creates a peft model and checks that it can be loaded. + """ + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + pretrained_model = get_peft_model(causal_lm_model, self.lora_config) + + _ = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model) + + def test_peft_requires_grad(self): + r""" + Check that the value head of the returned model has requires_grad=True. + """ + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + pretrained_model = get_peft_model(causal_lm_model, self.lora_config) + + model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model) + + # Check that the value head has requires_grad=True + self.assertTrue(model.v_head.summary.weight.requires_grad) + + def test_check_peft_model_nb_trainable_params(self): + r""" + Check that the number of trainable parameters is correct. + """ + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + pretrained_model = get_peft_model(causal_lm_model, self.lora_config) + + model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model) + + # Check that the number of trainable parameters is correct + nb_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + self.assertEqual(nb_trainable_params, 905) + + # Check that the number of trainable param for the non-peft model is correct + non_peft_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.causal_lm_model_id) + nb_trainable_params = sum(p.numel() for p in non_peft_model.parameters() if p.requires_grad) + self.assertEqual(nb_trainable_params, 2428641) + + def test_create_peft_model_from_config(self): + r""" + Simply creates a peft model and checks that it can be loaded. + """ + trl_model = AutoModelForCausalLMWithValueHead.from_pretrained( + self.causal_lm_model_id, peft_config=self.lora_config + ) + # Check that the number of trainable parameters is correct + nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad) + self.assertEqual(nb_trainable_params, 905) + + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config) + # Check that the number of trainable parameters is correct + nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad) + self.assertEqual(nb_trainable_params, 905) + + @require_torch_gpu_if_bnb_not_multi_backend_enabled + def test_create_bnb_peft_model_from_config(self): + r""" + Simply creates a peft model and checks that it can be loaded. + """ + from bitsandbytes.nn import Linear8bitLt + + trl_model = AutoModelForCausalLMWithValueHead.from_pretrained( + self.causal_lm_model_id, peft_config=self.lora_config, load_in_8bit=True + ) + # Check that the number of trainable parameters is correct + nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad) + self.assertEqual(nb_trainable_params, 905) + self.assertIsInstance(trl_model.pretrained_model.model.model.layers[0].mlp.gate_proj, Linear8bitLt) + + causal_lm_model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, load_in_8bit=True, device_map="auto" + ) + trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config) + # Check that the number of trainable parameters is correct + nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad) + self.assertEqual(nb_trainable_params, 905) + self.assertIsInstance(trl_model.pretrained_model.model.model.layers[0].mlp.gate_proj, Linear8bitLt) + + def test_save_pretrained_peft(self): + r""" + Check that the model can be saved and loaded properly. + """ + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + pretrained_model = get_peft_model(causal_lm_model, self.lora_config) + + model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model) + + with tempfile.TemporaryDirectory() as tmp_dir: + model.save_pretrained(tmp_dir) + + # check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory + self.assertTrue( + os.path.isfile(f"{tmp_dir}/adapter_model.safetensors"), + f"{tmp_dir}/adapter_model.safetensors does not exist", + ) + self.assertTrue( + os.path.exists(f"{tmp_dir}/adapter_config.json"), f"{tmp_dir}/adapter_config.json does not exist" + ) + + # check also for `pytorch_model.bin` and make sure it only contains `v_head` weights + self.assertTrue( + os.path.exists(f"{tmp_dir}/pytorch_model.bin"), f"{tmp_dir}/pytorch_model.bin does not exist" + ) + + # check that only keys that starts with `v_head` are in the dict + maybe_v_head = torch.load(f"{tmp_dir}/pytorch_model.bin", weights_only=True) + self.assertTrue( + all(k.startswith("v_head") for k in maybe_v_head.keys()), + f"keys in {tmp_dir}/pytorch_model.bin do not start with `v_head`", + ) + + model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(tmp_dir) + + # check all the weights are the same + for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters()): + self.assertTrue(torch.allclose(p1[1], p2[1]), f"{p1[0]} != {p2[0]}") + + def test_load_pretrained_peft(self): + r""" + Check that the model saved with peft class interface can be loaded properly. + """ + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + pretrained_model = get_peft_model(causal_lm_model, self.lora_config) + + model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model) + + with tempfile.TemporaryDirectory() as tmp_dir: + pretrained_model.save_pretrained(tmp_dir) + model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(tmp_dir) + + # check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory + self.assertTrue( + os.path.isfile(f"{tmp_dir}/adapter_model.safetensors"), + f"{tmp_dir}/adapter_model.safetensors does not exist", + ) + self.assertTrue( + os.path.exists(f"{tmp_dir}/adapter_config.json"), f"{tmp_dir}/adapter_config.json does not exist" + ) + + # check all the weights are the same + for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters()): + if p1[0] not in ["v_head.summary.weight", "v_head.summary.bias"]: + self.assertTrue(torch.allclose(p1[1], p2[1]), f"{p1[0]} != {p2[0]}") + + def test_continue_training_peft_model(self): + r""" + Load peft and checks that it can continue training. + """ + causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) + pretrained_model = get_peft_model(causal_lm_model, self.lora_config) + + with tempfile.TemporaryDirectory() as tmp_dir: + pretrained_model.save_pretrained(tmp_dir) + # set is_trainable to True + model = AutoModelForCausalLMWithValueHead.from_pretrained(tmp_dir, is_trainable=True) + # Check that the number of trainable parameters is correct + nb_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + self.assertEqual(nb_trainable_params, 905) diff --git a/tests/test_ppo_trainer.py b/tests/test_ppo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..3f664de1b3a089ce27f78da78d87c8659a882b2d --- /dev/null +++ b/tests/test_ppo_trainer.py @@ -0,0 +1,178 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest + +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer +from transformers.testing_utils import require_peft +from transformers.utils import is_peft_available + +from trl import PPOConfig, PPOTrainer +from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE + + +if is_peft_available(): + from peft import LoraConfig + + +class TestPPOTrainer(unittest.TestCase): + def setUp(self): + # Set up the models and tokenizer using the test model + self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, padding_side="left") + self.tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + + if self.tokenizer.chat_template is None: + self.tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE + + # Add reward and value models as in ppo.py + reward_model_id = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5" + self.value_model = AutoModelForSequenceClassification.from_pretrained(reward_model_id, num_labels=1) + self.reward_model = AutoModelForSequenceClassification.from_pretrained(reward_model_id, num_labels=1) + + # Load dataset + raw_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + def tokenize(example, tokenizer): + tokenized = tokenizer(text=example["prompt"]) + if tokenizer.eos_token_id is not None and tokenized["input_ids"][-1] != tokenizer.eos_token_id: + tokenized["input_ids"] = tokenized["input_ids"] + [tokenizer.eos_token_id] + tokenized["attention_mask"] = tokenized["attention_mask"] + [1] + return tokenized + + self.raw_dataset = raw_dataset.map(tokenize, fn_kwargs={"tokenizer": self.tokenizer}, remove_columns="prompt") + + def test_basic_training(self): + """Test basic PPO training configuration and verify model updates.""" + with tempfile.TemporaryDirectory() as tmp_dir: + # Capture initial weights + initial_critic_weights = {} + initial_policy_weights = {} + for name, param in self.value_model.named_parameters(): + initial_critic_weights[name] = param.clone().detach() + for name, param in self.model.named_parameters(): + initial_policy_weights[name] = param.clone().detach() + + # Configure training args similar to example script + training_args = PPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=4, + per_device_eval_batch_size=2, + num_ppo_epochs=2, # Decrease number of PPO epochs to speed up test + report_to="none", + ) + + # Create trainer + trainer = PPOTrainer( + args=training_args, + processing_class=self.tokenizer, + model=self.model, + ref_model=self.ref_model, + reward_model=self.reward_model, + value_model=self.value_model, + train_dataset=self.raw_dataset["train"], + eval_dataset=self.raw_dataset["test"], + ) + + # Train + trainer.train() + + # Check if critic weights have been updated + critic_weights_updated = False + for name, param in trainer.model.value_model.named_parameters(): + if not torch.allclose(initial_critic_weights[name], param.to("cpu")): + critic_weights_updated = True + break + + # Check if policy weights have been updated + policy_weights_updated = False + for name, param in trainer.model.policy.named_parameters(): + if not torch.allclose(initial_policy_weights[name], param.to("cpu")): + policy_weights_updated = True + break + + self.assertTrue(critic_weights_updated, "Critic weights were not updated during training") + self.assertTrue(policy_weights_updated, "Policy weights were not updated during training") + + @require_peft + def test_peft_training(self): + """Test PPO training with PEFT configuration and verify model updates.""" + with tempfile.TemporaryDirectory() as tmp_dir: + # Capture initial weights + initial_critic_weights = {} + initial_policy_weights = {} + for name, param in self.value_model.named_parameters(): + initial_critic_weights[name] = param.clone().detach() + for name, param in self.model.named_parameters(): + initial_policy_weights[name] = param.clone().detach() + + # Configure training args + training_args = PPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=4, + per_device_eval_batch_size=2, + num_ppo_epochs=2, # Decrease number of PPO epochs to speed up test + report_to="none", + ) + + # Configure PEFT + peft_config = LoraConfig( + r=32, + lora_alpha=16, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + # Create trainer with PEFT + trainer = PPOTrainer( + args=training_args, + processing_class=self.tokenizer, + model=self.model, + ref_model=None, + reward_model=self.reward_model, + value_model=self.value_model, + train_dataset=self.raw_dataset["train"], + eval_dataset=self.raw_dataset["test"], + peft_config=peft_config, + ) + + # Train + trainer.train() + + # Check if critic weights have been updated + critic_weights_updated = False + for name, param in trainer.model.value_model.named_parameters(): + if name in initial_critic_weights and not torch.allclose( + initial_critic_weights[name], param.to("cpu") + ): + critic_weights_updated = True + break + + # Check if policy weights have been updated - for PEFT we check the LoRA weights + policy_weights_updated = False + for name, param in trainer.model.policy.named_parameters(): + if "lora" in name.lower() and param.requires_grad: # Only check LoRA weights + # New weights should be non-zero if they've been updated + if not torch.allclose(param, torch.zeros_like(param)): + policy_weights_updated = True + break + + self.assertTrue(critic_weights_updated, "Critic weights were not updated during training") + self.assertTrue(policy_weights_updated, "Policy LoRA weights were not updated during training") diff --git a/tests/test_prm_trainer.py b/tests/test_prm_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..a826a8a1d8245339983ccedbd4bddb43e16cfb3b --- /dev/null +++ b/tests/test_prm_trainer.py @@ -0,0 +1,360 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest +from unittest.mock import MagicMock + +import torch +from datasets import Dataset, load_dataset +from parameterized import parameterized +from transformers import AutoModelForTokenClassification, AutoTokenizer, PreTrainedTokenizerBase +from transformers.testing_utils import require_peft +from transformers.utils import is_peft_available + +from trl import PRMConfig, PRMTrainer + + +if is_peft_available(): + from peft import LoraConfig, TaskType + + +class TestTokenizeRow(unittest.TestCase): + def setUp(self): + # Set up the mock tokenizer with specific behaviors + self.tokenizer = MagicMock(spec=PreTrainedTokenizerBase) + self.tokenizer.bos_token_id = 0 + self.tokenizer.eos_token_id = 2 + + def mock_encode(text, add_special_tokens): + token_map = { + "Which number is larger, 9.8 or 9.11?": [465, 6766, 318, 298], + "11 is greater than 8.": [4, 322, 12], + "Hence, 9.11 > 9.8.": [4995, 11, 22], + "\n": [1030], + "\n\n": [1030, 1030], + } + + return token_map[text] + + def mock_tokenizer_call(text, add_special_tokens): + return {"input_ids": mock_encode(text, add_special_tokens)} + + self.tokenizer.encode.side_effect = mock_encode + self.tokenizer.side_effect = mock_tokenizer_call + + def test_tokenize_row_no_truncation(self): + # Define the input features + features = { + "prompt": "Which number is larger, 9.8 or 9.11?", + "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."], + "labels": [True, False], + } + + # Call the method with no truncation + result = PRMTrainer.tokenize_row( + features=features, + tokenizer=self.tokenizer, + step_separator="\n", + max_length=None, + max_prompt_length=None, + max_completion_length=None, + train_on_last_step_only=False, + is_eval=False, + ) + + self.assertEqual( + result, + { + "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030], + "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, 0], + }, + ) + + def test_tokenize_row_train_on_last_step_only(self): + # Define the input features + features = { + "prompt": "Which number is larger, 9.8 or 9.11?", + "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."], + "labels": [True, False], + } + + result = PRMTrainer.tokenize_row( + features=features, + tokenizer=self.tokenizer, + step_separator="\n", + max_length=None, + max_prompt_length=None, + max_completion_length=None, + train_on_last_step_only=True, + is_eval=False, + ) + + self.assertEqual( + result, + { + "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030], + "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0], + }, + ) + + def test_tokenize_row_prompt_truncation(self): + # Define the input features + features = { + "prompt": "Which number is larger, 9.8 or 9.11?", + "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."], + "labels": [True, False], + } + + # Call the method with truncation on the completion + result = PRMTrainer.tokenize_row( + features=features, + tokenizer=self.tokenizer, + step_separator="\n", + max_length=None, + max_prompt_length=3, + max_completion_length=None, + train_on_last_step_only=False, + is_eval=False, + ) + + self.assertEqual( + result, + { + "input_ids": [6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030], + "labels": [-100, -100, -100, -100, -100, -100, 1, -100, -100, -100, 0], + }, + ) + + def test_tokenize_row_completion_truncation(self): + # Define the input features + features = { + "prompt": "Which number is larger, 9.8 or 9.11?", + "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."], + "labels": [True, False], + } + + # Call the method with truncation on the completion + result = PRMTrainer.tokenize_row( + features=features, + tokenizer=self.tokenizer, + step_separator="\n", + max_length=None, + max_prompt_length=None, + max_completion_length=6, + train_on_last_step_only=False, + is_eval=False, + ) + + self.assertEqual( + result, + { + "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11], + "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100], + }, + ) + + def test_tokenize_row_prompt_completion_truncation(self): + # Define the input features + features = { + "prompt": "Which number is larger, 9.8 or 9.11?", + "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."], + "labels": [True, False], + } + + # Call the method with truncation on the prompt and completion + result = PRMTrainer.tokenize_row( + features=features, + tokenizer=self.tokenizer, + step_separator="\n", + max_length=9, + max_prompt_length=None, + max_completion_length=None, + train_on_last_step_only=False, + is_eval=False, + ) + + self.assertEqual( + result, + { + "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030], + "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1], + }, + ) + + def test_tokenize_row_multi_token_separator(self): + # Define the input features + features = { + "prompt": "Which number is larger, 9.8 or 9.11?", + "completions": ["11 is greater than 8.", "Hence, 9.11 > 9.8."], + "labels": [True, False], + } + + # Call the method using multiple tokens as step_separator + result = PRMTrainer.tokenize_row( + features=features, + tokenizer=self.tokenizer, + step_separator="\n\n", + max_length=None, + max_prompt_length=None, + max_completion_length=None, + train_on_last_step_only=False, + is_eval=False, + ) + + self.assertEqual( + result, + { + "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 1030, 4995, 11, 22, 1030, 1030], + "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, 0], + }, + ) + + +class PRMTrainerTester(unittest.TestCase): + def setUp(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.model = AutoModelForTokenClassification.from_pretrained(model_id) + self.tokenizer = AutoTokenizer.from_pretrained(model_id) + + @parameterized.expand([True, False]) + def test_train_full(self, train_on_last_step_only): + with tempfile.TemporaryDirectory() as tmp_dir: + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_stepwise_supervision", split="train") + training_args = PRMConfig( + output_dir=tmp_dir, + report_to="none", + train_on_last_step_only=train_on_last_step_only, + ) + trainer = PRMTrainer( + model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset + ) + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + + def test_train_full_pretokenized(self): + with tempfile.TemporaryDirectory() as tmp_dir: + dummy_dataset = Dataset.from_dict( + { + "labels": [ + [-100, -100, -100, -100, -100, -100, -100, -100, -100, 0, -100, -100, 1], + [-100, -100, -100, -100, -100, -100, -100, -100, 0, -100, -100, 1, -100, -100, -100, -100, 0], + [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0, -100, -100, 1], + [-100, -100, -100, -100, -100, -100, -100, 1, -100, -100, 1], + [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, 0], + [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1], + [-100, -100, -100, -100, -100, -100, -100, -100, -100, 0], + [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, 0], + [-100, -100, -100, -100, -100, -100, -100, -100, 0, -100, -100, 0], + [-100, -100, -100, -100, -100, -100, 0, -100, -100, -100, -100, 0], + [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1], + [-100, -100, -100, -100, -100, -100, 0], + [-100, -100, -100, -100, -100, -100, -100, -100, 1], + [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0], + ], + "input_ids": [ + [46518, 374, 2664, 1091, 11, 1077, 752, 1744, 1112, 198, 27261, 13, 198], + [98923, 374, 2664, 1091, 11, 315, 3308, 11, 198, 17995, 13, 198, 1576, 31273, 12850, 13, 198], + [16374, 374, 2664, 1091, 1112, 1077, 594, 2506, 432, 6770, 11, 198, 6351, 13, 198], + [31137, 374, 2664, 1091, 979, 4362, 11, 198, 16965, 13, 198], + [31019, 374, 2664, 1091, 304, 3793, 315, 5944, 11, 198, 24034, 13, 198], + [98491, 374, 2664, 1091, 1112, 5310, 369, 91494, 13, 198], + [4418, 2897, 14579, 5310, 979, 3800, 1349, 432, 13, 198], + [20366, 5048, 7629, 944, 3281, 3322, 11, 7241, 1112, 198, 807, 1795, 279, 5601, 13, 198], + [15802, 14976, 487, 33327, 1045, 31787, 63443, 11, 198, 52400, 13, 198], + [13877, 1265, 2581, 1494, 49394, 11, 198, 7241, 20975, 91681, 13, 198], + [641, 279, 3579, 315, 71768, 11, 25066, 279, 61361, 311, 7942, 13, 198], + [7039, 374, 2664, 1091, 2937, 13, 198], + [26155, 374, 3545, 2664, 1091, 34933, 26537, 13, 198], + [2679, 279, 8129, 374, 4135, 311, 10339, 11, 432, 2578, 387, 264, 1661, 2884, 13, 198], + ], + } + ) + + training_args = PRMConfig(output_dir=tmp_dir, report_to="none") + trainer = PRMTrainer( + model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + + @require_peft + def test_train_lora(self): + peft_config = LoraConfig( + task_type=TaskType.TOKEN_CLS, + inference_mode=False, + r=8, + lora_alpha=32, + lora_dropout=0.1, + ) + with tempfile.TemporaryDirectory() as tmp_dir: + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_stepwise_supervision", split="train") + training_args = PRMConfig(output_dir=tmp_dir, max_steps=3, report_to="none") + trainer = PRMTrainer( + model=self.model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset, + peft_config=peft_config, + ) + previous_trainable_params = {} + previous_non_trainable_params = {} + + # due to a change in the way the modules to save are dealt in PEFT. + trainable_params_name = ["lora", "modules_to_save"] + + # check gradients are not None + for n, param in trainer.model.named_parameters(): + if any(t in n for t in trainable_params_name): + previous_trainable_params[n] = param.clone() + else: + previous_non_trainable_params[n] = param.clone() + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"]) + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) + + # Check that the non trainable parameters have not changed + for n, param in previous_non_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) + + def test_tags(self): + with tempfile.TemporaryDirectory() as tmp_dir: + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_stepwise_supervision", split="train") + training_args = PRMConfig(output_dir=tmp_dir, report_to="none") + trainer = PRMTrainer( + model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset + ) + self.assertEqual(trainer.model.model_tags, trainer._tag_names) diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..046363fdd3b6934f4eb9746a271342508f4fc256 --- /dev/null +++ b/tests/test_reward_trainer.py @@ -0,0 +1,235 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest + +import torch +from datasets import Dataset, load_dataset +from transformers import AutoModelForSequenceClassification, AutoTokenizer +from transformers.testing_utils import require_peft +from transformers.utils import is_peft_available + +from trl import RewardConfig, RewardTrainer, maybe_apply_chat_template +from trl.trainer.reward_trainer import _tokenize + + +if is_peft_available(): + from peft import LoraConfig, TaskType + + +class RewardTrainerTester(unittest.TestCase): + def setUp(self): + self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id) + self.model.config.pad_token_id = self.tokenizer.pad_token_id + + def test_preprocessing_conversational(self): + with tempfile.TemporaryDirectory() as tmp_dir: + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train") + training_args = RewardConfig(output_dir=tmp_dir, report_to="none") + trainer = RewardTrainer( + model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset + ) + dummy_dataset = dummy_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": self.tokenizer}) + dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": self.tokenizer}) + self.assertDictEqual(trainer.train_dataset[:], dummy_dataset[:]) + + def test_preprocessing_standard(self): + # No chat template, so we load a fresh tokenizer + tokenizer = AutoTokenizer.from_pretrained(self.model_id) + with tempfile.TemporaryDirectory() as tmp_dir: + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") + training_args = RewardConfig(output_dir=tmp_dir, report_to="none") + trainer = RewardTrainer( + model=self.model, args=training_args, processing_class=tokenizer, train_dataset=dummy_dataset + ) + dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": tokenizer}) + self.assertDictEqual(trainer.train_dataset[:], dummy_dataset[:]) + + def test_train_full(self): + with tempfile.TemporaryDirectory() as tmp_dir: + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train") + training_args = RewardConfig(output_dir=tmp_dir, max_steps=3, report_to="none") + trainer = RewardTrainer( + model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset + ) + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + + def test_train_full_pretokenized(self): + with tempfile.TemporaryDirectory() as tmp_dir: + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train") + dummy_dataset = dummy_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": self.tokenizer}) + dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": self.tokenizer}) + training_args = RewardConfig(output_dir=tmp_dir, max_steps=3, report_to="none") + trainer = RewardTrainer( + model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset + ) + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if param.sum() != 0: # ignore 0 biases + self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + + @require_peft + def test_train_lora(self): + peft_config = LoraConfig( + task_type=TaskType.SEQ_CLS, + inference_mode=False, + r=8, + lora_alpha=32, + lora_dropout=0.1, + ) + with tempfile.TemporaryDirectory() as tmp_dir: + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train") + training_args = RewardConfig(output_dir=tmp_dir, max_steps=3, report_to="none") + trainer = RewardTrainer( + model=self.model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset, + peft_config=peft_config, + ) + previous_trainable_params = {} + previous_non_trainable_params = {} + + # due to a change in the way the modules to save are dealt in PEFT. + trainable_params_name = ["lora", "modules_to_save"] + + # check gradients are not None + for n, param in trainer.model.named_parameters(): + if any(t in n for t in trainable_params_name): + previous_trainable_params[n] = param.clone() + else: + previous_non_trainable_params[n] = param.clone() + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"]) + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) + + # Check that the non trainable parameters have not changed + for n, param in previous_non_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) + + @require_peft + def test_train_lora_pretokenized(self): + peft_config = LoraConfig( + task_type=TaskType.SEQ_CLS, + inference_mode=False, + r=8, + lora_alpha=32, + lora_dropout=0.1, + ) + with tempfile.TemporaryDirectory() as tmp_dir: + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train") + dummy_dataset = dummy_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": self.tokenizer}) + dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": self.tokenizer}) + training_args = RewardConfig(output_dir=tmp_dir, max_steps=3, report_to="none") + trainer = RewardTrainer( + model=self.model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset, + peft_config=peft_config, + ) + previous_trainable_params = {} + previous_non_trainable_params = {} + + # due to a change in the way the modules to save are dealt in PEFT. + trainable_params_name = ["lora", "modules_to_save"] + + # check gradients are not None + for n, param in trainer.model.named_parameters(): + if any(t in n for t in trainable_params_name): + previous_trainable_params[n] = param.clone() + else: + previous_non_trainable_params[n] = param.clone() + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"]) + + # Check that the parameters have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) + + # Check that the non trainable parameters have not changed + for n, param in previous_non_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) + + def test_margin(self): + with tempfile.TemporaryDirectory() as tmp_dir: + dummy_dataset_dict = { + "input_ids_chosen": [ + torch.LongTensor([0, 1, 2]), + ], + "attention_mask_chosen": [ + torch.LongTensor([1, 1, 1]), + ], + "input_ids_rejected": [ + torch.LongTensor([0, 2]), + ], + "attention_mask_rejected": [ + torch.LongTensor([1, 1]), + ], + "margin": [ + torch.FloatTensor([1.0]), + ], + } + dummy_dataset = Dataset.from_dict(dummy_dataset_dict) + training_args = RewardConfig(output_dir=tmp_dir, report_to="none") + trainer = RewardTrainer( + model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset + ) + + batch = [dummy_dataset[0]] + batch = trainer.data_collator(batch) + batch = {k: v.to(trainer.model.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} + loss, outputs = trainer.compute_loss(trainer.model, batch, return_outputs=True) + + l_val = -torch.nn.functional.logsigmoid( + outputs["rewards_chosen"] - outputs["rewards_rejected"] - batch["margin"] + ).mean() + + self.assertLess(abs(loss - l_val), 1e-6) + + def test_tags(self): + with tempfile.TemporaryDirectory() as tmp_dir: + dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train") + training_args = RewardConfig(output_dir=tmp_dir, report_to="none") + trainer = RewardTrainer( + model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset + ) + self.assertEqual(trainer.model.model_tags, trainer._tag_names) diff --git a/tests/test_rewards.py b/tests/test_rewards.py new file mode 100644 index 0000000000000000000000000000000000000000..44e628dd3b846ea44276e0b82ed746ae3344f2e9 --- /dev/null +++ b/tests/test_rewards.py @@ -0,0 +1,65 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from trl.rewards import think_format_reward + + +class ThinkFormatRewardTester(unittest.TestCase): + def test_valid_format(self): + completions = [ + "This is my reasoning.This is my answer.", # Simple, one-line reasoning + "\nThis is my reasoning.\n\nThis is my answer.", # Multiline reasoning + "\nThis is\nmy reasoning.\n\nThis is my answer.", # Multiline reasoning + "\nThis is my reasoning.\nThis is my answer.", # Reasoning including other tags + "\nThis is my answer.", # Empty reasoning + ] + completions = [[{"content": completion}] for completion in completions] + expected_rewards = [1.0, 1.0, 1.0, 1.0, 1.0] # All should be valid + rewards = think_format_reward(completions) + self.assertEqual(rewards, expected_rewards) + + def test_invalid_format(self): + completions = [ + "\nThis is my reasoning.\nThis is my answer.", # No closing + "This is my reasoning.\nThis is my answer.", # No closing + "This is my reasoning. This is my answer.", # No tags + "This is my reasoning.\nThis is my answer.", # No tags + "This is my reasoning.\nThis is my answer.", # No opening + "This is my reasoning.This is my answer.", # No opening + "Thisis my reasoning.\nThis is my answer.", # tag in the middle + "This ismy reasoning.This is my answer.", # Nested tags + "This is\nmy\nreasoning.\nThis is my answer.", # Multiline + ] + completions = [[{"content": completion}] for completion in completions] + expected_rewards = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] # All should be invalid + rewards = think_format_reward(completions) + self.assertEqual(rewards, expected_rewards) + + def test_mixed_format(self): + completions = [ + "This is my reasoning.This is my answer.", # Valid + "\nThis is my reasoning.\n\nThis is my answer.", # Valid + "This is my reasoning.\nThis is my answer.", # Invalid + "This is my reasoning. This is my answer.", # Invalid + ] + completions = [[{"content": completion}] for completion in completions] + expected_rewards = [1.0, 1.0, 0.0, 0.0] + rewards = think_format_reward(completions) + self.assertEqual(rewards, expected_rewards) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rich_progress_callback.py b/tests/test_rich_progress_callback.py new file mode 100644 index 0000000000000000000000000000000000000000..f6e5a9a8983919f229de83561cb182b7d8efa99d --- /dev/null +++ b/tests/test_rich_progress_callback.py @@ -0,0 +1,69 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest + +import torch +import torch.nn as nn +from datasets import Dataset +from transformers import Trainer, TrainingArguments + +from trl.trainer.callbacks import RichProgressCallback + +from .testing_utils import require_rich + + +class DummyModel(nn.Module): + def __init__(self): + super().__init__() + self.a = nn.Parameter(torch.tensor(1.0)) + + def forward(self, x): + return self.a * x + + +@require_rich +class TestRichProgressCallback(unittest.TestCase): + def setUp(self): + self.dummy_model = DummyModel() + self.dummy_train_dataset = Dataset.from_list([{"x": 1.0, "y": 2.0}] * 5) + self.dummy_val_dataset = Dataset.from_list([{"x": 1.0, "y": 2.0}] * 101) + + def test_rich_progress_callback_logging(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + per_device_eval_batch_size=2, + per_device_train_batch_size=2, + num_train_epochs=4, + eval_strategy="steps", + eval_steps=1, + logging_strategy="steps", + logging_steps=1, + save_strategy="no", + report_to="none", + disable_tqdm=True, + ) + callbacks = [RichProgressCallback()] + trainer = Trainer( + model=self.dummy_model, + train_dataset=self.dummy_train_dataset, + eval_dataset=self.dummy_val_dataset, + args=training_args, + callbacks=callbacks, + ) + + trainer.train() + trainer.train() diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..f267a1cd4fdf1724171146b9af148f4562c02380 --- /dev/null +++ b/tests/test_rloo_trainer.py @@ -0,0 +1,213 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest + +import torch +from datasets import Dataset +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer + +from trl import RLOOConfig, RLOOTrainer + + +class RLOOTrainerTester(unittest.TestCase): + def setUp(self): + self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + + self.policy_model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.reward_model = AutoModelForSequenceClassification.from_pretrained(self.model_id) + self.policy_ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) + + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, padding_side="left") + self.tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + + def test_rloo_checkpoint(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = RLOOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + total_episodes=1, + report_to="none", + ) + + dummy_text = [{"content": "Hello World!", "role": "user"}] + dummy_data = self.tokenizer.apply_chat_template(dummy_text) + dummy_dataset = Dataset.from_dict({"input_ids": dummy_data}) + + trainer = RLOOTrainer( + config=training_args, + policy=self.policy_model, + reward_model=self.reward_model, + ref_policy=self.policy_ref_model, + processing_class=self.tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + ) + + trainer._save_checkpoint(trainer.model, trial=None) + + def test_rloo_reward(self): + local_batch_size = 3 + rloo_k = 4 + sequence_length = 5 # Add sequence length for testing token-level rewards + + # fmt: off + rlhf_reward = torch.tensor([ + 1, 2, 3, # first rlhf reward for three prompts + 2, 3, 4, # second rlhf reward for three prompts + 5, 6, 7, # third rlhf reward for three prompts + 8, 9, 10, # fourth rlhf reward for three prompts + ]).float() + + # Create padding mask where 1 indicates valid token, 0 indicates padding + padding_mask = torch.ones(local_batch_size * rloo_k, sequence_length) + # Set padding based on sequence lengths + sequence_lengths = torch.tensor([ + 3, 4, 3, # lengths for first batch + 4, 3, 4, # lengths for second batch + 3, 4, 3, # lengths for third batch + 4, 3, 4, # lengths for fourth batch + ]) + for i, length in enumerate(sequence_lengths): + padding_mask[i, length:] = 0 + + # Add kl tensor for testing token-level rewards + kl = torch.ones(local_batch_size * rloo_k, sequence_length) # Dummy KL values + # fmt: on + + # Test token-level KL rewards following OpenRLHF implementation + kl_coef = 0.1 + kl_reward = -kl_coef * kl + + # Find last non-padded position + eos_indices = padding_mask.size(1) - 1 - padding_mask.long().fliplr().argmax(dim=1, keepdim=True) + + # Create last reward tensor + last_reward = torch.zeros_like(kl) + last_reward.scatter_(dim=1, index=eos_indices, src=rlhf_reward.reshape(-1, 1)) + + # Test last_reward - should have rlhf_reward at the last non-padded position + for i, (length, reward) in enumerate(zip(sequence_lengths, rlhf_reward)): + # Check reward is at correct position + self.assertEqual(last_reward[i, length - 1].item(), reward.item()) + # Check zeros elsewhere + self.assertTrue(torch.all(last_reward[i, : length - 1] == 0)) + self.assertTrue(torch.all(last_reward[i, length:] == 0)) + + # Combine rewards + reward = last_reward + kl_reward + non_score_reward = kl_reward.sum(1) + token_level_rlhf_reward = reward.sum(1) + + # Test reward components + # KL reward should be -0.1 for each token in sequence length + expected_kl_reward = -0.1 * sequence_length # Each position gets -0.1 KL reward + torch.testing.assert_close(non_score_reward, torch.tensor(expected_kl_reward).expand_as(non_score_reward)) + + # Total reward should be rlhf_reward + kl_reward + expected_total = rlhf_reward + expected_kl_reward + torch.testing.assert_close(token_level_rlhf_reward, expected_total) + + # Test sequence-level rewards (existing test) + baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1) + advantages = torch.zeros_like(rlhf_reward) + for i in range(0, len(advantages), local_batch_size): + other_response_rlhf_rewards = [] + for j in range(0, len(advantages), local_batch_size): + if i != j: + other_response_rlhf_rewards.append(rlhf_reward[j : j + local_batch_size]) + advantages[i : i + local_batch_size] = rlhf_reward[i : i + local_batch_size] - torch.stack( + other_response_rlhf_rewards + ).mean(0) + self.assertLess((1 - (2 + 5 + 8) / 3 - advantages[0].item()), 1e-6) + self.assertLess((6 - (3 + 2 + 9) / 3 - advantages[7].item()), 1e-6) + + # Test vectorized implementation + rlhf_reward = rlhf_reward.reshape(rloo_k, local_batch_size) + baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1) + vec_advantages = rlhf_reward - baseline + torch.testing.assert_close(vec_advantages.flatten(), advantages) + + def test_rloo_training(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = RLOOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + per_device_eval_batch_size=2, + total_episodes=1, + num_train_epochs=1, + max_steps=2, + report_to="none", + ) + + # Create a simple dataset + dummy_text = [{"content": "Hello World!", "role": "user"}] + dummy_data = self.tokenizer.apply_chat_template(dummy_text) + dummy_dataset = Dataset.from_dict({"input_ids": [dummy_data, dummy_data]}) + + trainer = RLOOTrainer( + config=training_args, + policy=self.policy_model, + reward_model=self.reward_model, + ref_policy=self.policy_ref_model, + processing_class=self.tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + ) + + # Test that training completes without errors + trainer.train() + + # Check if objective/rlhf_reward is available + self.assertIn("objective/rlhf_reward", trainer.state.log_history[-1]) + + def test_rloo_training_with_custom_reward(self): + # dummy reward function + def reward_function(texts): + # based on length of text + rewards = [len(text) for text in texts] + return rewards + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = RLOOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + per_device_eval_batch_size=2, + total_episodes=1, + num_train_epochs=1, + max_steps=2, + report_to="none", + ) + + # Create a simple dataset + dummy_text = [{"content": "Hello World!", "role": "user"}] + dummy_data = self.tokenizer.apply_chat_template(dummy_text) + dummy_dataset = Dataset.from_dict({"input_ids": [dummy_data, dummy_data]}) + + trainer = RLOOTrainer( + config=training_args, + policy=self.policy_model, + reward_model=reward_function, + ref_policy=self.policy_ref_model, + processing_class=self.tokenizer, + train_dataset=dummy_dataset, + eval_dataset=dummy_dataset, + ) + + # Test that training completes without errors + trainer.train() + + # Check if objective/rlhf_reward is available + self.assertIn("objective/rlhf_reward", trainer.state.log_history[-1]) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..0988de6274a9108a2ef56ac187a76905d13d8e87 --- /dev/null +++ b/tests/test_sft_trainer.py @@ -0,0 +1,1352 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import tempfile +import unittest + +import numpy as np +import torch +from datasets import Dataset, Image, Sequence, load_dataset +from parameterized import parameterized +from transformers import ( + AutoModelForCausalLM, + AutoProcessor, + AutoTokenizer, + LlavaForConditionalGeneration, + TrainingArguments, + is_vision_available, +) +from transformers.testing_utils import require_flash_attn, require_peft, require_vision +from transformers.utils import is_peft_available + +from trl import SFTConfig, SFTTrainer +from trl.trainer import ConstantLengthDataset, DataCollatorForCompletionOnlyLM +from trl.trainer.sft_trainer import DataCollatorForLanguageModeling + + +def formatting_prompts_func(example): + text = f"### Question: {example['question']}\n ### Answer: {example['answer']}" + return text + + +def formatting_func_for_pretokenized(example): + return example["input_ids"] + + +if is_peft_available(): + from peft import LoraConfig, PeftModel, get_peft_model + +if is_vision_available(): + from PIL import Image as PILImage + + +class TestDataCollatorForLanguageModeling(unittest.TestCase): + def test_basic_padding(self): + """Test basic padding functionality without completion masks.""" + self.collator = DataCollatorForLanguageModeling(pad_token_id=0) + examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}] + + result = self.collator(examples) + + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) + torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) + torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]])) + torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]])) + + def test_completion_mask(self): + """Test completion mask functionality.""" + self.collator = DataCollatorForLanguageModeling(pad_token_id=0) + examples = [ + {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]}, + {"input_ids": [4, 5], "completion_mask": [0, 1]}, + ] + + result = self.collator(examples) + + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) + torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) + torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]])) + torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3], [-100, 5, -100]])) + + def test_completion_only_loss_disabled(self): + """Test behavior when completion_only_loss is disabled.""" + collator = DataCollatorForLanguageModeling(pad_token_id=0, completion_only_loss=False) + examples = [ + {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]}, + {"input_ids": [4, 5], "completion_mask": [0, 1]}, + ] + + result = collator(examples) + + # Labels should not be masked when completion_only_loss=False + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) + torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) + torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]])) + torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]])) + + def test_padding_free_mode(self): + """Test padding-free mode where sequences are concatenated.""" + collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True) + examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}] + + result = collator(examples) + + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]])) + torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1]])) + torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]])) + torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5]])) + + def test_padding_free_with_completion_mask(self): + """Test padding-free mode with completion masks.""" + collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True) + examples = [ + {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]}, + {"input_ids": [4, 5], "completion_mask": [1, 1]}, + ] + + result = collator(examples) + + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]])) + torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1]])) + torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]])) + torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, 4, 5]])) + + def test_pad_to_multiple_of(self): + """Test padding to multiple of specified value.""" + collator = DataCollatorForLanguageModeling(pad_token_id=0, pad_to_multiple_of=4) + examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}] + + result = collator(examples) + + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]])) + torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])) + torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0], [0, 1, 0, 0]])) + torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, -100], [4, 5, -100, -100]])) + + def test_custom_position_ids(self): + """Test handling of custom position IDs in examples.""" + self.collator = DataCollatorForLanguageModeling(pad_token_id=0) + examples = [{"input_ids": [1, 2, 3], "position_ids": [0, 0, 1]}, {"input_ids": [4, 5], "position_ids": [0, 1]}] + + result = self.collator(examples) + + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) + torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) + torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 0, 1], [0, 1, 0]])) + torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]])) + + def test_single_example(self): + """Test collator with a single example.""" + self.collator = DataCollatorForLanguageModeling(pad_token_id=0) + examples = [{"input_ids": [1, 2, 3, 4]}] + + result = self.collator(examples) + + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4]])) + torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1]])) + torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 3]])) + torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4]])) + + def test_different_pad_token_id(self): + """Test with different pad token ID.""" + collator = DataCollatorForLanguageModeling(pad_token_id=999) + examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}] + + result = collator(examples) + + torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 999]])) + torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) + torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]])) + torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]])) + + +class SFTTrainerTester(unittest.TestCase): + r""" """ + + def setUp(self): + self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.dummy_dataset = Dataset.from_dict( + { + "question": [ + "Does llamas know how to code?", + "Does llamas know how to fly?", + "Does llamas know how to talk?", + "Does llamas know how to code?", + "Does llamas know how to fly?", + "Does llamas know how to talk?", + "Does llamas know how to swim?", + ], + "answer": [ + "Yes, llamas are very good at coding.", + "No, llamas can't fly.", + "Yes, llamas are very good at talking.", + "Yes, llamas are very good at coding.", + "No, llamas can't fly.", + "Yes, llamas are very good at talking.", + "No, llamas can't swim.", + ], + "text": [ + "### Question: Does llamas know how to code?\n ### Answer: Yes, llamas are very good at coding.", + "### Question: Does llamas know how to fly?\n ### Answer: No, llamas can't fly.", + "### Question: Does llamas know how to talk?\n ### Answer: Yes, llamas are very good at talking.", + "### Question: Does llamas know how to code?\n ### Answer: Yes, llamas are very good at coding.", + "### Question: Does llamas know how to fly?\n ### Answer: No, llamas can't fly.", + "### Question: Does llamas know how to talk?\n ### Answer: Yes, llamas are very good at talking.", + "### Question: Does llamas know how to swim?\n ### Answer: No, llamas can't swim.", + ], + } + ) + self.dummy_tokenized_dataset = Dataset.from_dict( + { + "input_ids": [ + self.tokenizer.encode( + "TRL is a library to post-train LLMs and diffusion models with methods such as Supervised Fine-tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO)." + ) + ] + * 10 + } + ) + + self.conversational_lm_dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling") + self.standard_prompt_completion_dataset = load_dataset( + "trl-internal-testing/zen", "standard_prompt_completion" + ) + + if is_vision_available(): + self.dummy_vsft_instruction_dataset = Dataset.from_dict( + { + "messages": [ + [ + { + "role": "user", + "content": [{"type": "text", "text": "What is in this image?"}, {"type": "image"}], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "It is random noise."}], + }, + { + "role": "user", + "content": [{"type": "text", "text": "Oh ye, you are right, what is 1+1"}], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "2"}], + }, + ], + [ + { + "role": "user", + "content": [{"type": "text", "text": "What is in this image?"}, {"type": "image"}], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "It is random noise."}], + }, + ], + ], + "images": [ + [PILImage.fromarray((np.random.rand(40, 50, 3) * 255).astype("uint8")).convert("RGBA")], + [PILImage.fromarray((np.random.rand(50, 60, 3) * 255).astype("uint8")).convert("RGBA")], + ], + } + ) + self.dummy_vsft_instruction_dataset.cast_column("images", Sequence(Image())) + self.dummy_vsft_instruction_dataset = self.dummy_vsft_instruction_dataset.cast_column( + "images", Sequence(Image()) + ) + + self.train_dataset = ConstantLengthDataset( + self.tokenizer, + self.dummy_dataset, + formatting_func=formatting_prompts_func, + seq_length=16, + num_of_sequences=16, + ) + + self.eval_dataset = ConstantLengthDataset( + self.tokenizer, + self.dummy_dataset, + formatting_func=formatting_prompts_func, + seq_length=16, + num_of_sequences=16, + ) + + self.train_dataset_from_pretokenized = ConstantLengthDataset( + self.tokenizer, + self.dummy_tokenized_dataset, + seq_length=16, + num_of_sequences=16, + formatting_func=formatting_func_for_pretokenized, + ) + + self.eval_dataset_from_pretokenized = ConstantLengthDataset( + self.tokenizer, + self.dummy_tokenized_dataset, + seq_length=16, + num_of_sequences=16, + formatting_func=formatting_func_for_pretokenized, + ) + + def test_constant_length_dataset_with_pretokenized_data(self): + constant_len_dataset = ConstantLengthDataset( + self.tokenizer, + self.dummy_tokenized_dataset, + formatting_func=formatting_func_for_pretokenized, + ) + + assert len(constant_len_dataset) == len(self.dummy_tokenized_dataset) + assert len(constant_len_dataset) > 0 + + for example in constant_len_dataset: + assert "input_ids" in example + assert "labels" in example + + assert len(example["input_ids"]) == constant_len_dataset.seq_length + assert len(example["labels"]) == constant_len_dataset.seq_length + + decoded_text = self.tokenizer.decode(example["input_ids"]) + assert ("TRL" in decoded_text) and ("(DPO)" in decoded_text) + + def test_constant_length_dataset(self): + formatted_dataset = ConstantLengthDataset( + self.tokenizer, + self.dummy_dataset, + formatting_func=formatting_prompts_func, + ) + + self.assertEqual(len(formatted_dataset), len(self.dummy_dataset)) + self.assertGreater(len(formatted_dataset), 0) + + for example in formatted_dataset: + self.assertIn("input_ids", example) + self.assertIn("labels", example) + + self.assertEqual(len(example["input_ids"]), formatted_dataset.seq_length) + self.assertEqual(len(example["labels"]), formatted_dataset.seq_length) + + decoded_text = self.tokenizer.decode(example["input_ids"]) + self.assertTrue(("Question" in decoded_text) and ("Answer" in decoded_text)) + + def test_backward_compatibility(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = TrainingArguments( + output_dir=tmp_dir, + per_device_train_batch_size=2, + hub_token="not_a_real_token", + report_to="none", + ) + + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=self.train_dataset, + formatting_func=formatting_prompts_func, + ) + + self.assertEqual(trainer.args.hub_token, training_args.hub_token) + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + trainer.train() + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + def test_with_pretokenized_data_packing(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + packing=True, + report_to="none", + ) + + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=self.train_dataset_from_pretokenized, + ) + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + def test_uncorrect_data(self): + with tempfile.TemporaryDirectory() as tmp_dir: + # Shoud work as SFTTrainer natively supports conversational lm dataset + training_args = SFTConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_length=32, # make sure there is at least 1 packed sequence + packing=True, + report_to="none", + ) + _ = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=self.conversational_lm_dataset["train"], + ) + + # Same, but without packing + training_args = SFTConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + packing=False, + report_to="none", + ) + _ = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=self.conversational_lm_dataset["train"], + ) + + # Same, but with packing with `max_length` + training_args = SFTConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_length=16, # make sure there is at least 1 packed sequence + packing=True, + report_to="none", + ) + _ = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=self.standard_prompt_completion_dataset["train"], + ) + + # Same but with prompt completion dataset + training_args = SFTConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + packing=False, + report_to="none", + ) + _ = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=self.standard_prompt_completion_dataset["train"], + ) + + # Should work as dummy dataset are supported with a formatting function + training_args = SFTConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_length=32, # make sure there is at least 1 packed sequence + packing=True, + report_to="none", + ) + _ = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=self.dummy_dataset, + formatting_func=formatting_prompts_func, + ) + + def test_sft_trainer_with_model_num_train_epochs(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + num_train_epochs=2, + per_device_train_batch_size=2, + packing=True, + report_to="none", + ) + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=self.train_dataset, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + num_train_epochs=2, + max_length=16, + packing=True, + report_to="none", + ) + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=self.dummy_dataset, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + num_train_epochs=2, + per_device_train_batch_size=2, + max_length=16, + report_to="none", + ) + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=self.dummy_dataset, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + def test_with_model_(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_length=16, + packing=True, + report_to="none", + ) + trainer = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.dummy_dataset, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # with formatting_func + packed + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_length=16, + packing=True, + report_to="none", + ) + trainer = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.dummy_dataset, + formatting_func=formatting_prompts_func, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_length=16, + report_to="none", + ) + trainer = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.dummy_dataset, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + def test_with_multiple_eval_datasets(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + eval_strategy="steps", + eval_steps=3, + report_to="none", + ) + + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=self.train_dataset, + eval_dataset={ + "data1": self.eval_dataset, + "data2": self.eval_dataset, + }, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + self.assertIsNotNone(trainer.state.log_history[0]["eval_data1_loss"]) + self.assertIsNotNone(trainer.state.log_history[1]["eval_data2_loss"]) + + def test_data_collator_completion_lm(self): + response_template = "### Response:\n" + data_collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=self.tokenizer, mlm=False) + + text = """\n\n### Instructions:\nHello all this should be masked\n\n### Response:\nI have not been masked correctly.""" + encoded_text = self.tokenizer(text) + + examples = [encoded_text] + + batch = data_collator(examples) + labels = batch["labels"] + last_pad_idx = np.where(labels == -100)[1][-1] + result_text = self.tokenizer.decode(batch["input_ids"][0, last_pad_idx + 1 :]) + self.assertEqual(result_text, "I have not been masked correctly.") + + def test_data_collator_completion_lm_with_multiple_text(self): + tokenizer = copy.deepcopy(self.tokenizer) + tokenizer.padding_side = "left" + + response_template = "### Response:\n" + data_collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer, mlm=False) + + text1 = """\n\n### Instructions:\nHello all this should be masked\n\n### Response:\nI have not been masked correctly.""" + text2 = """\n\n### Instructions:\nThis is another longer text that should also be masked. This text is significantly longer than the previous one.\n\n### Response:\nI have not been masked correctly.""" + + encoded_text1 = tokenizer(text1) + encoded_text2 = tokenizer(text2) + + examples = [encoded_text1, encoded_text2] + + batch = data_collator(examples) + + for i in range(2): + labels = batch["labels"][i] + last_pad_idx = np.where(labels == -100)[0][-1] + result_text = tokenizer.decode(batch["input_ids"][i, last_pad_idx + 1 :]) + self.assertEqual(result_text, "I have not been masked correctly.") + + def test_data_collator_chat_completion_lm(self): + instruction_template = "### Human:" + assistant_template = "### Assistant:" + data_collator = DataCollatorForCompletionOnlyLM( + response_template=assistant_template, + instruction_template=instruction_template, + tokenizer=self.tokenizer, + mlm=False, + ) + + text = """### Human: Hello all this should be masked.### Assistant: I should not be masked.### Human: All this should be masked too.### Assistant: I should not be masked too.""" + encoded_text = self.tokenizer(text) + + examples = [encoded_text] + + batch = data_collator(examples) + labels = batch["labels"] + non_masked_tokens = batch["input_ids"][labels != -100] + result_text = self.tokenizer.decode(non_masked_tokens) + self.assertEqual(result_text, " I should not be masked. I should not be masked too.") + + def test_data_collator_chat_completion_lm_with_multiple_text(self): + tokenizer = copy.deepcopy(self.tokenizer) + tokenizer.padding_side = "left" + + instruction_template = "### Human:" + assistant_template = "### Assistant:" + data_collator = DataCollatorForCompletionOnlyLM( + response_template=assistant_template, + instruction_template=instruction_template, + tokenizer=tokenizer, + mlm=False, + ) + + text1 = """### Human: Hello all this should be masked.### Assistant: I should not be masked.""" + text2 = """### Human: Hello all this should be masked.### Assistant: I should not be masked.### Human: All this should be masked too.### Assistant: I should not be masked too.""" + encoded_text1 = tokenizer(text1) + encoded_text2 = tokenizer(text2) + + examples = [encoded_text1, encoded_text2] + + batch = data_collator(examples) + labels = batch["labels"] + input_ids = batch["input_ids"] + + non_masked_tokens1 = input_ids[0][labels[0] != -100] + result_text1 = tokenizer.decode(non_masked_tokens1) + self.assertEqual(result_text1, " I should not be masked.") + + non_masked_tokens2 = input_ids[1][labels[1] != -100] + result_text2 = tokenizer.decode(non_masked_tokens2) + self.assertEqual(result_text2, " I should not be masked. I should not be masked too.") + + def test_with_model_neftune(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + neftune_noise_alpha=5, + packing=True, + report_to="none", + ) + trainer = SFTTrainer( + model=self.model, + args=training_args, + train_dataset=self.train_dataset, + ) + + trainer.model = trainer._activate_neftune(trainer.model) + + device = trainer.model.get_input_embeddings().weight.device + trainer.model.train() + + torch.random.manual_seed(42) + embeds_neftune = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device)) + + torch.random.manual_seed(24) + embeds_neftune_2 = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device)) + + self.assertFalse(torch.allclose(embeds_neftune, embeds_neftune_2)) + self.assertGreater(len(trainer.model.get_input_embeddings()._forward_hooks), 0) + + trainer.neftune_hook_handle.remove() + + trainer.train() + + # Make sure forward pass works fine + _ = trainer.model(torch.LongTensor([[1, 0, 1]]).to(device)) + self.assertEqual(len(trainer.model.get_input_embeddings()._forward_hooks), 0) + + @require_peft + def test_peft_str(self): + with tempfile.TemporaryDirectory() as tmp_dir: + peft_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + task_type="CAUSAL_LM", + ) + + training_args = SFTConfig( + packing=True, + output_dir=tmp_dir, + report_to="none", + ) + + _ = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=self.train_dataset, + peft_config=peft_config, + ) + + @require_peft + def test_peft_sft_trainer(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + packing=True, + report_to="none", + ) + + peft_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + task_type="CAUSAL_LM", + ) + + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=self.train_dataset, + peft_config=peft_config, + ) + + self.assertTrue(isinstance(trainer.model, PeftModel)) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + @require_peft + def test_peft_and_gradient_checkpointing(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + gradient_checkpointing=True, + report_to="none", + ) + + peft_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM") + + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=self.train_dataset, + peft_config=peft_config, + ) + + self.assertIsInstance(trainer.model, PeftModel) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + @require_peft + def test_peft_neftune(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + neftune_noise_alpha=5, + packing=True, + report_to="none", + ) + + peft_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + task_type="CAUSAL_LM", + ) + + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=self.train_dataset, + peft_config=peft_config, + ) + + trainer.model = trainer._activate_neftune(trainer.model) + + self.assertIsInstance(trainer.model, PeftModel) + + device = trainer.model.get_input_embeddings().weight.device + trainer.model.train() + + torch.random.manual_seed(42) + embeds_neftune = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device)) + + torch.random.manual_seed(24) + embeds_neftune_2 = trainer.model.get_input_embeddings()(torch.LongTensor([[1, 0, 1]]).to(device)) + + self.assertFalse(torch.allclose(embeds_neftune, embeds_neftune_2)) + self.assertGreater(len(trainer.model.get_input_embeddings()._forward_hooks), 0) + + trainer.neftune_hook_handle.remove() + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Make sure forward pass works fine to check if embeddings forward is not broken. + trainer.model(torch.LongTensor([[1, 0, 1]]).to(device)) + self.assertEqual(len(trainer.model.get_input_embeddings()._forward_hooks), 0) + + @require_peft + def test_peft_tag(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + gradient_checkpointing=True, + packing=True, + report_to="none", + ) + + peft_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + task_type="CAUSAL_LM", + ) + + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=self.train_dataset, + peft_config=peft_config, + ) + + for tag in ["sft", "trl"]: + self.assertIn(tag, trainer.model.model_tags) + + @require_peft + def test_tag(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + gradient_checkpointing=True, + packing=True, + report_to="none", + ) + + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=self.train_dataset, + ) + + for tag in ["sft", "trl"]: + self.assertIn(tag, trainer.model.model_tags) + + def test_only_train_packing(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + gradient_checkpointing=True, + packing=True, + max_length=128, # make sure there is at least 1 packed sequence + eval_packing=False, + report_to="none", + ) + + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=self.conversational_lm_dataset["train"], + eval_dataset=self.conversational_lm_dataset["test"], + ) + + self.assertEqual(len(trainer.train_dataset["input_ids"]), 7) # w/ this dataset, we end up with 46 seqs + self.assertEqual(len(trainer.eval_dataset["input_ids"]), len(self.conversational_lm_dataset["test"])) + + def test_eval_packing(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_length=128, # make sure there is at least 1 packed sequence + packing=True, + report_to="none", + ) + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=self.conversational_lm_dataset["train"], + eval_dataset=self.conversational_lm_dataset["test"], + ) + + self.assertEqual(len(trainer.train_dataset["input_ids"]), 7) # w/ this dataset, we end up with 46 seqs + self.assertEqual(len(trainer.eval_dataset["input_ids"]), 1) # w/ this dataset, we end up with 6 seqs + + def test_no_packing(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_length=128, # make sure there is at least 1 packed sequence + packing=False, + report_to="none", + ) + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=self.conversational_lm_dataset["train"], + eval_dataset=self.conversational_lm_dataset["test"], + ) + + self.assertEqual(len(trainer.train_dataset["input_ids"]), len(self.conversational_lm_dataset["train"])) + self.assertEqual(len(trainer.eval_dataset["input_ids"]), len(self.conversational_lm_dataset["test"])) + + @require_vision + def test_skip_prepare_dataset(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + remove_unused_columns=False, + dataset_kwargs={"skip_prepare_dataset": True}, + report_to="none", + ) + + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=self.dummy_vsft_instruction_dataset, + ) + self.assertEqual(trainer.train_dataset.features, self.dummy_vsft_instruction_dataset.features) + + def test_skip_prepare_dataset_with_no_packing(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + remove_unused_columns=False, + packing=False, + dataset_kwargs={"skip_prepare_dataset": True}, + report_to="none", + ) + + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=self.dummy_dataset, + ) + self.assertEqual(trainer.train_dataset.features, self.dummy_dataset.features) + + @require_vision + def test_llava(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + remove_unused_columns=False, + dataset_kwargs={"skip_prepare_dataset": True}, + report_to="none", + ) + tiny_llava = LlavaForConditionalGeneration.from_pretrained( + "trl-internal-testing/tiny-LlavaForConditionalGeneration" + ) + processor = AutoProcessor.from_pretrained("trl-internal-testing/tiny-LlavaForConditionalGeneration") + + processor.chat_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}""" + + def collate_fn(examples): + # Get the texts and images, and apply the chat template + texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples] + images = [example["images"][0] for example in examples] + + # Tokenize the texts and process the images + batch = processor(texts, images, return_tensors="pt", padding=True) + + # The labels are the input_ids, and we mask the padding tokens in the loss computation + labels = batch["input_ids"].clone() + labels[labels == processor.tokenizer.pad_token_id] = -100 + batch["labels"] = labels + + return batch + + trainer = SFTTrainer( + model=tiny_llava, + args=training_args, + data_collator=collate_fn, + train_dataset=self.dummy_vsft_instruction_dataset, + ) + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + def test_torch_dtype(self): + # See https://github.com/huggingface/trl/issues/1751 + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + model_init_kwargs={"torch_dtype": torch.float16}, + report_to="none", + ) + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=self.train_dataset, + formatting_func=formatting_prompts_func, + ) + self.assertEqual(trainer.model.config.torch_dtype, torch.float16) + + +# This new tester aims to replace the first one at some point +class SFTTrainerTester2(unittest.TestCase): + def test_train(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Initialize the trainer + training_args = SFTConfig(output_dir=tmp_dir, report_to="none") + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + def test_train_model(self): + # Instantiate the model + model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Initialize the trainer + training_args = SFTConfig(output_dir=tmp_dir, report_to="none") + trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + def test_train_model_torch_dtype(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Initialize the trainer + training_args = SFTConfig( + output_dir=tmp_dir, model_init_kwargs={"torch_dtype": torch.float16}, report_to="none" + ) + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + # Check the torch dtype + self.assertEqual(new_param.dtype, torch.float16) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + @require_peft + def test_train_peft_model(self): + # Get the base model + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + model = AutoModelForCausalLM.from_pretrained(model_id) + + # Get the base model parameter names + base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] + + # Turn the model into a peft model + lora_config = LoraConfig() + model = get_peft_model(model, lora_config) + + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Initialize the trainer + training_args = SFTConfig(output_dir=tmp_dir, report_to="none") + trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the peft params have changed and the base model params have not changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + if n in base_param_names: # We expect the base model parameters to be the same + self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + elif ( + "base_layer" not in n + ): # We expect the peft parameters to be different (except for the base layer) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + def test_train_with_non_chatml_conversational_data(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train") + + # Rename role/content to from/value to ensure SFT works with non-chatML conversational data + def rename_fields(example: list[dict]): + return {"conversations": [{"from": m["role"], "value": m["content"]} for m in example["messages"]]} + + dataset = dataset.map(rename_fields, remove_columns="messages") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Initialize the trainer + training_args = SFTConfig(output_dir=tmp_dir, report_to="none") + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + def test_train_with_pretokenized_data(self): + # Get the dataset + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + tokenizer = AutoTokenizer.from_pretrained(model_id) + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + + def tokenize_example(example): + return tokenizer(example["text"]) + + # Apply tokenization + tokenized_dataset = dataset.map(tokenize_example, remove_columns=["text"]) + + with tempfile.TemporaryDirectory() as tmp_dir: + # Initialize the trainer + training_args = SFTConfig(output_dir=tmp_dir, report_to="none") + trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=tokenized_dataset) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + def test_train_with_iterable_dataset(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train", streaming=True) + + with tempfile.TemporaryDirectory() as tmp_dir: + # Initialize the trainer + training_args = SFTConfig(output_dir=tmp_dir, max_steps=3, report_to="none") + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + def test_train_with_data_collator_for_completion_only_and_padding_free(self): + # Get the dataset + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train") + + tokenizer = AutoTokenizer.from_pretrained(model_id) + response_template = "<|im_start|>assistant\n" + collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer, padding_free=True) + + with tempfile.TemporaryDirectory() as tmp_dir: + # Initialize the trainer + training_args = SFTConfig(output_dir=tmp_dir, report_to="none") + trainer = SFTTrainer(model=model_id, args=training_args, train_dataset=dataset, data_collator=collator) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + @require_flash_attn + def test_train_padding_free(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Initialize the trainer + training_args = SFTConfig( + output_dir=tmp_dir, + padding_free=True, + model_init_kwargs={"attn_implementation": "flash_attention_2"}, + bf16=True, # flash_attention_2 only supports bf16 and fp16 + report_to="none", + ) + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + + @parameterized.expand([("ffd",), ("wrapped",)]) + def test_train_packing(self, packing_strategy): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Initialize the trainer + training_args = SFTConfig( + output_dir=tmp_dir, packing=True, packing_strategy=packing_strategy, max_length=10, report_to="none" + ) + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") diff --git a/tests/test_trainers_args.py b/tests/test_trainers_args.py new file mode 100644 index 0000000000000000000000000000000000000000..df1ef02614a1bbdbc7f4b44d70ee85becaf27556 --- /dev/null +++ b/tests/test_trainers_args.py @@ -0,0 +1,410 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest + +from datasets import load_dataset +from parameterized import parameterized +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer + +from trl import ( + BCOConfig, + BCOTrainer, + CPOConfig, + CPOTrainer, + DPOConfig, + DPOTrainer, + KTOConfig, + KTOTrainer, + NashMDConfig, + NashMDTrainer, + OnlineDPOConfig, + OnlineDPOTrainer, + ORPOConfig, + ORPOTrainer, + RewardConfig, + RewardTrainer, + SFTConfig, + SFTTrainer, + XPOConfig, + XPOTrainer, +) + +from .testing_utils import require_sklearn + + +class TrainerArgTester(unittest.TestCase): + @require_sklearn + def test_bco(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + tokenizer = AutoTokenizer.from_pretrained(model_id) + dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference", split="train") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = BCOConfig( + tmp_dir, + max_length=256, + max_prompt_length=64, + max_completion_length=64, + beta=0.5, + label_pad_token_id=-99, + padding_value=-99, + truncation_mode="keep_start", + # generate_during_eval=True, # ignore this one, it requires wandb + is_encoder_decoder=True, + precompute_ref_log_probs=True, + model_init_kwargs={"trust_remote_code": True}, + ref_model_init_kwargs={"trust_remote_code": True}, + dataset_num_proc=4, + prompt_sample_size=512, + min_density_ratio=0.2, + max_density_ratio=20.0, + ) + trainer = BCOTrainer( + model=model_id, + ref_model=model_id, + args=training_args, + train_dataset=dataset, + processing_class=tokenizer, + ) + self.assertEqual(trainer.args.max_length, 256) + self.assertEqual(trainer.args.max_prompt_length, 64) + self.assertEqual(trainer.args.max_completion_length, 64) + self.assertEqual(trainer.args.beta, 0.5) + self.assertEqual(trainer.args.label_pad_token_id, -99) + self.assertEqual(trainer.args.padding_value, -99) + self.assertEqual(trainer.args.truncation_mode, "keep_start") + # self.assertEqual(trainer.args.generate_during_eval, True) + self.assertEqual(trainer.args.is_encoder_decoder, True) + self.assertEqual(trainer.args.precompute_ref_log_probs, True) + self.assertEqual(trainer.args.model_init_kwargs, {"trust_remote_code": True}) + self.assertEqual(trainer.args.ref_model_init_kwargs, {"trust_remote_code": True}) + self.assertEqual(trainer.args.dataset_num_proc, 4) + self.assertEqual(trainer.args.prompt_sample_size, 512) + self.assertEqual(trainer.args.min_density_ratio, 0.2) + self.assertEqual(trainer.args.max_density_ratio, 20.0) + + def test_cpo(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + tokenizer = AutoTokenizer.from_pretrained(model_id) + dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = CPOConfig( + tmp_dir, + max_length=256, + max_prompt_length=64, + max_completion_length=64, + beta=0.5, + label_smoothing=0.5, + loss_type="hinge", + disable_dropout=False, + cpo_alpha=0.5, + simpo_gamma=0.2, + label_pad_token_id=-99, + padding_value=-99, + truncation_mode="keep_start", + # generate_during_eval=True, # ignore this one, it requires wandb + is_encoder_decoder=True, + model_init_kwargs={"trust_remote_code": True}, + dataset_num_proc=4, + ) + trainer = CPOTrainer(model=model_id, args=training_args, train_dataset=dataset, processing_class=tokenizer) + self.assertEqual(trainer.args.max_length, 256) + self.assertEqual(trainer.args.max_prompt_length, 64) + self.assertEqual(trainer.args.max_completion_length, 64) + self.assertEqual(trainer.args.beta, 0.5) + self.assertEqual(trainer.args.label_smoothing, 0.5) + self.assertEqual(trainer.args.loss_type, "hinge") + self.assertEqual(trainer.args.disable_dropout, False) + self.assertEqual(trainer.args.cpo_alpha, 0.5) + self.assertEqual(trainer.args.simpo_gamma, 0.2) + self.assertEqual(trainer.args.label_pad_token_id, -99) + self.assertEqual(trainer.args.padding_value, -99) + self.assertEqual(trainer.args.truncation_mode, "keep_start") + # self.assertEqual(trainer.args.generate_during_eval, True) + self.assertEqual(trainer.args.is_encoder_decoder, True) + self.assertEqual(trainer.args.model_init_kwargs, {"trust_remote_code": True}) + self.assertEqual(trainer.args.dataset_num_proc, 4) + + def test_dpo(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + tokenizer = AutoTokenizer.from_pretrained(model_id) + dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = DPOConfig( + tmp_dir, + beta=0.5, + label_smoothing=0.5, + loss_type="hinge", + label_pad_token_id=-99, + padding_value=-99, + truncation_mode="keep_start", + max_length=256, + max_prompt_length=64, + max_completion_length=64, + disable_dropout=False, + # generate_during_eval=True, # ignore this one, it requires wandb + precompute_ref_log_probs=True, + dataset_num_proc=4, + model_init_kwargs={"trust_remote_code": True}, + ref_model_init_kwargs={"trust_remote_code": True}, + model_adapter_name="dummy_adapter", + ref_adapter_name="dummy_adapter", + reference_free=True, + force_use_ref_model=True, + f_divergence_type="js_divergence", + f_alpha_divergence_coef=0.5, + # sync_ref_model=True, # cannot be True when precompute_ref_log_probs=True. Don't test this. + ref_model_mixup_alpha=0.5, + ref_model_sync_steps=32, + rpo_alpha=0.5, + discopop_tau=0.1, + ) + trainer = DPOTrainer( + model=model_id, + ref_model=model_id, + args=training_args, + train_dataset=dataset, + processing_class=tokenizer, + ) + self.assertEqual(trainer.args.beta, 0.5) + self.assertEqual(trainer.args.label_smoothing, 0.5) + self.assertEqual(trainer.args.loss_type, "hinge") + self.assertEqual(trainer.args.label_pad_token_id, -99) + self.assertEqual(trainer.args.padding_value, -99) + self.assertEqual(trainer.args.truncation_mode, "keep_start") + self.assertEqual(trainer.args.max_length, 256) + self.assertEqual(trainer.args.max_prompt_length, 64) + self.assertEqual(trainer.args.max_completion_length, 64) + self.assertEqual(trainer.args.disable_dropout, False) + # self.assertEqual(trainer.args.generate_during_eval, True) + self.assertEqual(trainer.args.precompute_ref_log_probs, True) + self.assertEqual(trainer.args.dataset_num_proc, 4) + self.assertEqual(trainer.args.model_init_kwargs, {"trust_remote_code": True}) + self.assertEqual(trainer.args.ref_model_init_kwargs, {"trust_remote_code": True}) + self.assertEqual(trainer.args.model_adapter_name, "dummy_adapter") + self.assertEqual(trainer.args.ref_adapter_name, "dummy_adapter") + self.assertEqual(trainer.args.reference_free, True) + self.assertEqual(trainer.args.force_use_ref_model, True) + self.assertEqual(trainer.args.f_divergence_type, "js_divergence") + self.assertEqual(trainer.args.f_alpha_divergence_coef, 0.5) + # self.assertEqual(trainer.args.sync_ref_model, True) + self.assertEqual(trainer.args.ref_model_mixup_alpha, 0.5) + self.assertEqual(trainer.args.ref_model_sync_steps, 32) + self.assertEqual(trainer.args.rpo_alpha, 0.5) + self.assertEqual(trainer.args.discopop_tau, 0.1) + + def test_kto(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + tokenizer = AutoTokenizer.from_pretrained(model_id) + dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference", split="train") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = KTOConfig( + tmp_dir, + max_length=256, + max_prompt_length=64, + max_completion_length=64, + beta=0.5, + desirable_weight=0.5, + undesirable_weight=0.5, + label_pad_token_id=-99, + padding_value=-99, + truncation_mode="keep_start", + # generate_during_eval=True, # ignore this one, it requires wandb + is_encoder_decoder=True, + precompute_ref_log_probs=True, + model_init_kwargs={"trust_remote_code": True}, + ref_model_init_kwargs={"trust_remote_code": True}, + dataset_num_proc=4, + ) + trainer = KTOTrainer( + model=model_id, + ref_model=model_id, + args=training_args, + train_dataset=dataset, + processing_class=tokenizer, + ) + self.assertEqual(trainer.args.max_length, 256) + self.assertEqual(trainer.args.max_prompt_length, 64) + self.assertEqual(trainer.args.max_completion_length, 64) + self.assertEqual(trainer.args.beta, 0.5) + self.assertEqual(trainer.args.desirable_weight, 0.5) + self.assertEqual(trainer.args.undesirable_weight, 0.5) + self.assertEqual(trainer.args.label_pad_token_id, -99) + self.assertEqual(trainer.args.padding_value, -99) + self.assertEqual(trainer.args.truncation_mode, "keep_start") + # self.assertEqual(trainer.args.generate_during_eval, True) + self.assertEqual(trainer.args.is_encoder_decoder, True) + self.assertEqual(trainer.args.precompute_ref_log_probs, True) + self.assertEqual(trainer.args.model_init_kwargs, {"trust_remote_code": True}) + self.assertEqual(trainer.args.ref_model_init_kwargs, {"trust_remote_code": True}) + self.assertEqual(trainer.args.dataset_num_proc, 4) + + @parameterized.expand([(False,), (True,)]) + def test_nash_md(self, mixtures_coef_list): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = AutoModelForCausalLM.from_pretrained(model_id) + ref_model = AutoModelForCausalLM.from_pretrained(model_id) + reward_model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1) + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = NashMDConfig( + tmp_dir, + mixture_coef=0.5 if not mixtures_coef_list else [0.5, 0.6], + ) + trainer = NashMDTrainer( + args=training_args, + processing_class=tokenizer, + model=model, + ref_model=ref_model, + reward_model=reward_model, + train_dataset=dataset, + ) + self.assertEqual(trainer.args.mixture_coef, 0.5 if not mixtures_coef_list else [0.5, 0.6]) + + @parameterized.expand([(False,), (True,)]) + def test_online_dpo(self, beta_list): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = AutoModelForCausalLM.from_pretrained(model_id) + ref_model = AutoModelForCausalLM.from_pretrained(model_id) + reward_model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1) + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = OnlineDPOConfig( + tmp_dir, + max_new_tokens=42, + temperature=0.5, + missing_eos_penalty=0.33, + beta=0.6 if not beta_list else [0.6, 0.7], + loss_type="hinge", + dataset_num_proc=4, + ) + trainer = OnlineDPOTrainer( + model=model, + ref_model=ref_model, + reward_model=reward_model, + args=training_args, + train_dataset=dataset, + processing_class=tokenizer, + reward_processing_class=tokenizer, + ) + self.assertEqual(trainer.args.max_new_tokens, 42) + self.assertEqual(trainer.args.temperature, 0.5) + self.assertEqual(trainer.args.missing_eos_penalty, 0.33) + self.assertEqual(trainer.args.beta, 0.6 if not beta_list else [0.6, 0.7]) + self.assertEqual(trainer.args.loss_type, "hinge") + self.assertEqual(trainer.args.dataset_num_proc, 4) + + def test_orpo(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + tokenizer = AutoTokenizer.from_pretrained(model_id) + dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = ORPOConfig( + tmp_dir, + max_length=256, + max_prompt_length=64, + max_completion_length=64, + beta=0.5, + disable_dropout=False, + label_pad_token_id=-99, + padding_value=-99, + truncation_mode="keep_start", + # generate_during_eval=True, # ignore this one, it requires wandb + is_encoder_decoder=True, + model_init_kwargs={"trust_remote_code": True}, + dataset_num_proc=4, + ) + trainer = ORPOTrainer( + model=model_id, args=training_args, train_dataset=dataset, processing_class=tokenizer + ) + self.assertEqual(trainer.args.max_length, 256) + self.assertEqual(trainer.args.max_prompt_length, 64) + self.assertEqual(trainer.args.max_completion_length, 64) + self.assertEqual(trainer.args.beta, 0.5) + self.assertEqual(trainer.args.disable_dropout, False) + self.assertEqual(trainer.args.label_pad_token_id, -99) + + def test_reward(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = AutoModelForCausalLM.from_pretrained(model_id) + dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = RewardConfig( + tmp_dir, + max_length=256, + dataset_num_proc=4, + center_rewards_coefficient=0.1, + ) + trainer = RewardTrainer( + model=model, + args=training_args, + train_dataset=dataset, + processing_class=tokenizer, + ) + self.assertEqual(trainer.args.max_length, 256) + self.assertEqual(trainer.args.dataset_num_proc, 4) + self.assertEqual(trainer.args.center_rewards_coefficient, 0.1) + + def test_sft(self): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = SFTConfig( + tmp_dir, + dataset_text_field="dummy_text_field", + packing=True, + max_length=256, + dataset_num_proc=4, + neftune_noise_alpha=0.1, + model_init_kwargs={"trust_remote_code": True}, + dataset_kwargs={"append_concat_token": True, "skip_prepare_dataset": True}, + eval_packing=True, + ) + trainer = SFTTrainer(model_id, args=training_args, train_dataset=dataset) + self.assertEqual(trainer.args.dataset_text_field, "dummy_text_field") + self.assertEqual(trainer.args.packing, True) + self.assertEqual(trainer.args.max_length, 256) + self.assertEqual(trainer.args.dataset_num_proc, 4) + self.assertEqual(trainer.args.neftune_noise_alpha, 0.1) + self.assertEqual(trainer.args.model_init_kwargs, {"trust_remote_code": True}) + self.assertIn("append_concat_token", trainer.args.dataset_kwargs) + self.assertEqual(trainer.args.dataset_kwargs["append_concat_token"], True) + self.assertEqual(trainer.args.eval_packing, True) + + @parameterized.expand([(False,), (True,)]) + def test_xpo(self, alpha_list): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = AutoModelForCausalLM.from_pretrained(model_id) + ref_model = AutoModelForCausalLM.from_pretrained(model_id) + reward_model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1) + dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = XPOConfig( + tmp_dir, + alpha=0.5 if not alpha_list else [0.5, 0.6], + ) + trainer = XPOTrainer( + args=training_args, + processing_class=tokenizer, + model=model, + ref_model=ref_model, + reward_model=reward_model, + train_dataset=dataset, + ) + self.assertEqual(trainer.args.alpha, 0.5 if not alpha_list else [0.5, 0.6]) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0f1ca0765e0a1ec0931d36e8441e61f0658dee7a --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,618 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap +import unittest +from io import StringIO +from unittest.mock import patch + +import numpy as np +import torch +from datasets import load_dataset +from parameterized import parameterized +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig +from transformers.testing_utils import require_peft +from transformers.utils import is_peft_available + +from trl import ModelConfig +from trl.trainer import compute_accuracy +from trl.trainer.utils import ( + DataCollatorForChatML, + batch_generation, + decode_and_strip_padding, + flush_left, + flush_right, + generate_model_card, + get_peft_config, + pad, + print_prompt_completions_sample, + selective_log_softmax, +) + +from .testing_utils import require_rich + + +if is_peft_available(): + from peft import LoraConfig + + +class TestPad(unittest.TestCase): + def test_pad_1_dim_left(self): + x = torch.tensor([1, 2, 3]) + y = torch.tensor([4, 5]) + output = pad((x, y), padding_value=0, padding_side="left") + expected = torch.tensor([[1, 2, 3], [0, 4, 5]]) + self.assertTrue(torch.equal(output, expected)) + + def test_pad_1_dim_right(self): + x = torch.tensor([1, 2, 3]) + y = torch.tensor([4, 5]) + output = pad((x, y), padding_value=0, padding_side="right") + expected = torch.tensor([[1, 2, 3], [4, 5, 0]]) + self.assertTrue(torch.equal(output, expected)) + + def test_pad_2_dim_left(self): + x = torch.tensor([[1, 2], [3, 4]]) + y = torch.tensor([[5, 6]]) + output = pad((x, y), padding_value=0, padding_side="left") + expected = torch.tensor( + [ + [[1, 2], [3, 4]], + [[0, 0], [5, 6]], + ] + ) + self.assertTrue(torch.equal(output, expected)) + + def test_pad_2_dim_right(self): + x = torch.tensor([[1, 2], [3, 4]]) + y = torch.tensor([[5, 6]]) + output = pad((x, y), padding_value=0, padding_side="right") + expected = torch.tensor( + [ + [[1, 2], [3, 4]], + [[5, 6], [0, 0]], + ] + ) + self.assertTrue(torch.equal(output, expected)) + + def test_pad_2_dim_right_multidim(self): + x = torch.tensor([[1, 2], [3, 4]]) + y = torch.tensor([[5]]) + output = pad((x, y), padding_value=0, padding_side="right") + expected = torch.tensor( + [ + [[1, 2], [3, 4]], + [[5, 0], [0, 0]], + ] + ) + self.assertTrue(torch.equal(output, expected)) + + def test_pad_to_multiple_of_1(self): + x = torch.tensor([1, 2, 3]) + y = torch.tensor([4, 5]) + # Max length is 3, pad to multiple of 4 + output = pad((x, y), padding_value=0, padding_side="right", pad_to_multiple_of=4) + expected = torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]]) + self.assertTrue(torch.equal(output, expected)) + + def test_pad_to_multiple_of_2(self): + x = torch.tensor([1, 2, 3, 4, 5]) + y = torch.tensor([6, 7, 8]) + # Max length is 3, pad to multiple of 4 + output = pad((x, y), padding_value=0, padding_side="right", pad_to_multiple_of=4) + expected = torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0], [6, 7, 8, 0, 0, 0, 0, 0]]) + self.assertTrue(torch.equal(output, expected)) + + def test_pad_to_multiple_of_side_left(self): + x = torch.tensor([1, 2, 3, 4, 5]) + y = torch.tensor([6, 7, 8]) + # Max length is 3, pad to multiple of 4 + output = pad((x, y), padding_value=0, padding_side="left", pad_to_multiple_of=4) + expected = torch.tensor([[0, 0, 0, 1, 2, 3, 4, 5], [0, 0, 0, 0, 0, 6, 7, 8]]) + self.assertTrue(torch.equal(output, expected)) + + def test_pad_to_multiple_of_no_extra_padding(self): + x = torch.tensor([1, 2, 3, 4]) + y = torch.tensor([5, 6, 7, 8]) + # Already multiple of 4 + output = pad((x, y), padding_value=0, padding_side="left", pad_to_multiple_of=4) + expected = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]) + self.assertTrue(torch.equal(output, expected)) + + +@require_peft +class TestGetPEFTConfig(unittest.TestCase): + def test_create_peft_config_use_peft_false(self): + """Test that when use_peft is False, the function returns None.""" + model_args = ModelConfig(use_peft=False) + peft_config = get_peft_config(model_args) + self.assertIsNone(peft_config) + + def test_create_peft_config_use_peft_true(self): + """Test that when use_peft is True, the function returns a LoraConfig object.""" + # Provide non-default values to the model config for testing + peft_kwargs = { + "lora_r": 8, + "lora_alpha": 16, + "lora_dropout": 0.1, + "lora_task_type": "SEQ_CLS", + "use_rslora": True, + "lora_target_modules": ["up_proj", "down_proj"], + "lora_modules_to_save": ["up_proj"], + } + model_args = ModelConfig(use_peft=True, **peft_kwargs) + peft_config = get_peft_config(model_args) + self.assertTrue(isinstance(peft_config, LoraConfig)) + for arg, value in peft_kwargs.items(): + # Test that lists of modules are converted to sets + if arg == "lora_target_modules": + value = set(value) + # Rename the argument to match the LoraConfig attribute name + if arg in ["lora_r", "lora_task_type", "lora_target_modules", "lora_modules_to_save"]: + arg = arg[len("lora_") :] if arg.startswith("lora_") else arg + + self.assertEqual(getattr(peft_config, arg), value) + + +class TestDecodeAndStripPadding(unittest.TestCase): + def setUp(self): + self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + + def test_example_with_padding(self): + inputs = self.tokenizer(["Hello world", "Hello"], padding=True, return_tensors="pt") + decoded = decode_and_strip_padding(inputs["input_ids"], self.tokenizer) + self.assertEqual(decoded, ["Hello world", "Hello"]) + + def test_example_without_padding(self): + inputs = self.tokenizer(["Hello", "Hello"], padding=False, return_tensors="pt") + decoded = decode_and_strip_padding(inputs["input_ids"], self.tokenizer) + self.assertEqual(decoded, ["Hello", "Hello"]) + + +class TestGenerateModelCard(unittest.TestCase): + def test_full(self): + model_card = generate_model_card( + base_model="username/my_base_model", + model_name="my_model", + hub_model_id="username/my_hub_model", + dataset_name="username/my_dataset", + tags=["trl", "trainer-tag"], + wandb_url="https://wandb.ai/username/project_id/runs/abcd1234", + comet_url="https://www.comet.com/username/project_id/experiment_id", + trainer_name="My Trainer", + trainer_citation="@article{my_trainer, ...}", + paper_title="My Paper", + paper_id="1234.56789", + ) + card_text = str(model_card) + self.assertIn("[username/my_base_model](https://huggingface.co/username/my_base_model)", card_text) + self.assertIn("my_model", card_text) + self.assertIn('pipeline("text-generation", model="username/my_hub_model", device="cuda")', card_text) + self.assertIn("datasets: username/my_dataset", card_text) + self.assertIn("](https://wandb.ai/username/project_id/runs/abcd1234)", card_text) + self.assertIn("](https://www.comet.com/username/project_id/experiment_id", card_text) + self.assertIn("My Trainer", card_text) + self.assertIn("```bibtex\n@article{my_trainer, ...}\n```", card_text) + self.assertIn("[My Paper](https://huggingface.co/papers/1234.56789)", card_text) + + def test_val_none(self): + model_card = generate_model_card( + base_model=None, + model_name="my_model", + hub_model_id="username/my_hub_model", + dataset_name=None, + tags=[], + wandb_url=None, + comet_url=None, + trainer_name="My Trainer", + trainer_citation=None, + paper_title=None, + paper_id=None, + ) + card_text = str(model_card) + self.assertIn("my_model", card_text) + self.assertIn('pipeline("text-generation", model="username/my_hub_model", device="cuda")', card_text) + self.assertIn("My Trainer", card_text) + + +class TestDataCollatorForChatML(unittest.TestCase): + def setUp(self): + # Initialize the tokenizer + self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + # Define token IDs + self.bos_token_id = self.tokenizer.bos_token_id if self.tokenizer.bos_token_id is not None else 1 + self.eos_token_id = self.tokenizer.eos_token_id if self.tokenizer.eos_token_id is not None else 2 + # Token ID for "true", the last assistant's response in the example: + self.ignore_index = -100 + self.max_length = 1024 + self.messages_key = "messages" + + # Example input + dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train") + self.examples = dataset.to_list() + + # Initialize the data collator + self.collator = DataCollatorForChatML( + tokenizer=self.tokenizer, + max_length=self.max_length, + ignore_index=self.ignore_index, + ) + + def test_data_collator_for_chatml(self): + # Process the data + data = self.collator(self.examples) + + # Verify basic shapes and types + self.assertIn("input_ids", data) + self.assertIn("attention_mask", data) + self.assertIn("labels", data) + self.assertIn("prompts", data) + self.assertIn("prompt_attention_mask", data) + + # Decode input_ids and labels for verification + input_ids = data["input_ids"][0].tolist() + labels = data["labels"][0].tolist() + prompt_only = data["prompts"][0].tolist() + + # Get the last assistant's response for comparison + last_message = self.examples[0][self.messages_key][-1] + self.assertEqual(last_message["role"], "assistant", "Last message should be from assistant") + last_assistant_response = last_message["content"] + + # Verify that input_ids contain both prompt and response + decoded_input = self.tokenizer.decode(input_ids) + self.assertIn(last_assistant_response, decoded_input, "Input should contain assistant's response") + + # Verify that prompts only contain the conversation up to the last response + decoded_prompt = self.tokenizer.decode(prompt_only) + self.assertNotIn(last_assistant_response, decoded_prompt, "Prompt should not contain assistant's response") + + # Verify labels are -100 for non-assistant parts + prompt_length = len(prompt_only) + self.assertTrue( + all(label == self.ignore_index for label in labels[:prompt_length]), + "Labels should be ignore_index for prompt tokens", + ) + + # Verify labels match assistant response after prompt + # Add a filter to remove any trailing tokens after the first <|im_end|> + last_assistant_response_with_end = last_assistant_response + self.tokenizer.eos_token + last_assistant_response_tokens = self.tokenizer.encode( + last_assistant_response_with_end, add_special_tokens=False + ) + + response_labels = [] + for label in labels[prompt_length:]: + if label == self.ignore_index: + continue + response_labels.append(label) + if label == self.tokenizer.convert_tokens_to_ids("<|im_end|>"): + break + self.assertEqual( + response_labels, + last_assistant_response_tokens, + "Labels should match assistant response tokens", + ) + + # Verify there isn't a generation prompt at the end + generation_prompt = "<|im_start|>assistant" + self.assertFalse( + decoded_input.strip().endswith(generation_prompt), + f"Input should not end with generation prompt '{generation_prompt}'", + ) + + self.assertEqual( + response_labels, + last_assistant_response_tokens, + "Labels should match assistant response tokens", + ) + + +class TestBatchGeneration(unittest.TestCase): + def setUp(self): + # Initialize the tokenizer + self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + + self.generation_config = GenerationConfig( + max_new_tokens=128, + temperature=0.5, + do_sample=True, + top_k=0, + pad_token_id=self.tokenizer.pad_token_id, + ) + + # Example input + dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train") + self.examples = dataset["messages"] + self.mini_batch_size = 3 + + def test_mini_batch_generation(self): + batch = [ + self.tokenizer.apply_chat_template(example[:-1], add_generation_prompt=True, tokenize=False) + for example in self.examples + ] + queries = self.tokenizer(batch, padding=True, return_tensors="pt")["input_ids"] + bs, context_length = queries.shape + + query_responses, logits = batch_generation( + self.model, queries, self.mini_batch_size, self.tokenizer.pad_token_id, self.generation_config + ) + + max_length_query = query_responses.shape[1] + max_length_logits = max_length_query - context_length + + self.assertGreater(max_length_query, context_length) + self.assertEqual(query_responses.shape, (bs, max_length_query)) + self.assertEqual(logits.shape, (bs, max_length_logits, self.model.config.vocab_size)) + + def test_single_batch_generation(self): + batch = [ + self.tokenizer.apply_chat_template(example[:-1], add_generation_prompt=True, tokenize=False) + for example in self.examples + ] + queries = self.tokenizer(batch, padding=True, return_tensors="pt")["input_ids"] + bs, context_length = queries.shape + + query_responses, logits = batch_generation( + self.model, queries, bs, self.tokenizer.pad_token_id, self.generation_config + ) + + max_length_query = query_responses.shape[1] + max_length_logits = max_length_query - context_length + + self.assertGreater(max_length_query, context_length) + self.assertEqual(query_responses.shape, (bs, max_length_query)) + self.assertEqual(logits.shape, (bs, max_length_logits, self.model.config.vocab_size)) + + +class TestComputeAccuracy(unittest.TestCase): + def test_token_classification_task(self): + eval_pred = ( + np.array( + [ + [[0.1, 0.9], [0.8, 0.2]], # Batch 1 + [[0.3, 0.7], [0.6, 0.4]], # Batch 2 + ] + ), + np.array([[0, 1], [1, 0]]), + ) + expected_accuracy = 0.5 # 2 matches, 2 mismatches + result = compute_accuracy(eval_pred) + self.assertAlmostEqual(result["accuracy"], expected_accuracy) + + def test_token_classification_task_with_ignored_tokens_0(self): + eval_pred = ( + np.array( + [ + [[0.1, 0.9], [0.8, 0.2]], # Batch 1 + [[0.3, 0.7], [0.6, 0.4]], # Batch 2 + ] + ), + np.array([[1, 0], [1, -100]]), + ) + expected_accuracy = 1.0 # All non-ignored tokens match + result = compute_accuracy(eval_pred) + self.assertAlmostEqual(result["accuracy"], expected_accuracy) + + def test_token_classification_task_with_ignored_tokens_1(self): + eval_pred = ( + np.array( + [ + [[0.1, 0.9], [0.8, 0.2]], # Batch 1 + [[0.3, 0.7], [0.6, 0.4]], # Batch 2 + ] + ), + np.array([[1, 1], [0, -100]]), + ) + expected_accuracy = 1 / 3 # 1 match, 2 mismatch, 1 ignored + result = compute_accuracy(eval_pred) + self.assertAlmostEqual(result["accuracy"], expected_accuracy) + + def test_rewards_comparison_task(self): + eval_pred = ( + np.array( + [ + [0.9, 0.1], # Batch 1 + [0.6, 0.4], # Batch 2 + [0.5, 0.5], # Batch 3 (equal) + ] + ), + np.array([0, 1, 1]), + ) + expected_accuracy = 0.5 # 1 match, 1 mismatch, 1 equal (ignored) + + with self.assertWarns(UserWarning) as cm: + result = compute_accuracy(eval_pred) + + self.assertAlmostEqual(result["accuracy"], expected_accuracy) + expected_warning = ( + "There are 1 out of 3 instances where the predictions for both options are equal. " + "These instances are ignored in the accuracy computation." + ) + self.assertEqual(str(cm.warning), expected_warning) + + +class TestFlushLeft(unittest.TestCase): + def test_basic_case(self): + mask = torch.tensor([[0, 0, 1, 1, 1], [0, 1, 1, 0, 0]]) + tensor1 = torch.tensor([[0, 0, 2, 3, 4], [0, 5, 6, 0, 0]]) + tensor2 = torch.tensor([[0, 0, 7, 8, 9], [0, 10, 11, 0, 0]]) + new_mask, new_tensor1, new_tensor2 = flush_left(mask, tensor1, tensor2) + + expected_mask = torch.tensor([[1, 1, 1], [1, 1, 0]]) + expected_tensor1 = torch.tensor([[2, 3, 4], [5, 6, 0]]) + expected_tensor2 = torch.tensor([[7, 8, 9], [10, 11, 0]]) + + self.assertTrue(torch.equal(new_mask, expected_mask)) + self.assertTrue(torch.equal(new_tensor1, expected_tensor1)) + self.assertTrue(torch.equal(new_tensor2, expected_tensor2)) + + def test_single_row(self): + mask = torch.tensor([[0, 0, 1, 1]]) + tensor1 = torch.tensor([[0, 0, 2, 3]]) + new_mask, new_tensor1 = flush_left(mask, tensor1) + + expected_mask = torch.tensor([[1, 1]]) + expected_tensor1 = torch.tensor([[2, 3]]) + + self.assertTrue(torch.equal(new_mask, expected_mask)) + self.assertTrue(torch.equal(new_tensor1, expected_tensor1)) + + def test_no_shift_needed(self): + mask = torch.tensor([[1, 1, 0, 0], [1, 0, 0, 0]]) + tensor1 = torch.tensor([[5, 6, 0, 0], [7, 0, 0, 0]]) + new_mask, new_tensor1 = flush_left(mask, tensor1) + + expected_mask = torch.tensor([[1, 1], [1, 0]]) + expected_tensor1 = torch.tensor([[5, 6], [7, 0]]) + + self.assertTrue(torch.equal(new_mask, expected_mask)) + self.assertTrue(torch.equal(new_tensor1, expected_tensor1)) + + def test_no_tensors(self): + mask = torch.tensor([[0, 0, 1, 1, 1], [0, 1, 1, 0, 0]]) + new_mask = flush_left(mask) + expected_mask = torch.tensor([[1, 1, 1], [1, 1, 0]]) + self.assertTrue(torch.equal(new_mask, expected_mask)) + + +class TestFlushRight(unittest.TestCase): + def test_basic_case(self): + mask = torch.tensor([[1, 1, 1, 0, 0], [0, 0, 1, 1, 0]]) + tensor1 = torch.tensor([[2, 3, 4, 0, 0], [0, 0, 5, 6, 0]]) + tensor2 = torch.tensor([[7, 8, 9, 0, 0], [0, 0, 10, 11, 0]]) + new_mask, new_tensor1, new_tensor2 = flush_right(mask, tensor1, tensor2) + + expected_mask = torch.tensor([[1, 1, 1], [0, 1, 1]]) + expected_tensor1 = torch.tensor([[2, 3, 4], [0, 5, 6]]) + expected_tensor2 = torch.tensor([[7, 8, 9], [0, 10, 11]]) + + self.assertTrue(torch.equal(new_mask, expected_mask)) + self.assertTrue(torch.equal(new_tensor1, expected_tensor1)) + self.assertTrue(torch.equal(new_tensor2, expected_tensor2)) + + def test_single_row(self): + mask = torch.tensor([[1, 1, 0, 0]]) + tensor1 = torch.tensor([[2, 3, 0, 0]]) + new_mask, new_tensor1 = flush_right(mask, tensor1) + + expected_mask = torch.tensor([[1, 1]]) + expected_tensor1 = torch.tensor([[2, 3]]) + + self.assertTrue(torch.equal(new_mask, expected_mask)) + self.assertTrue(torch.equal(new_tensor1, expected_tensor1)) + + def test_no_shift_needed(self): + mask = torch.tensor([[0, 0, 1, 1], [0, 0, 0, 1]]) + tensor1 = torch.tensor([[0, 0, 5, 6], [0, 0, 0, 7]]) + new_mask, new_tensor1 = flush_right(mask, tensor1) + + expected_mask = torch.tensor([[1, 1], [0, 1]]) + expected_tensor1 = torch.tensor([[5, 6], [0, 7]]) + + self.assertTrue(torch.equal(new_mask, expected_mask)) + self.assertTrue(torch.equal(new_tensor1, expected_tensor1)) + + def test_no_tensors(self): + mask = torch.tensor([[1, 1, 1, 0, 0], [0, 0, 1, 1, 0]]) + new_mask = flush_right(mask) + expected_mask = torch.tensor([[1, 1, 1], [0, 1, 1]]) + self.assertTrue(torch.equal(new_mask, expected_mask)) + + +class TestSelectiveLogSoftmax(unittest.TestCase): + @parameterized.expand([(torch.float64,), (torch.float32,), (torch.float16,), (torch.bfloat16,)]) + def test_selective_log_softmax(self, dtype): + """Test selective_log_softmax with logits of different dtypes""" + vocab_size = 1024 + batch_size = 4 + seq_len = 32 + + input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len)) + logits = torch.randn(batch_size, seq_len, vocab_size, dtype=dtype) + + expected_output = torch.gather(logits.log_softmax(-1), dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1) + actual_output = selective_log_softmax(logits, input_ids) + + if dtype in [torch.float16, torch.bfloat16]: + # half-precision dtypes fall back to an exact method + self.assertTrue(torch.equal(actual_output, expected_output)) + else: + torch.testing.assert_close(actual_output, expected_output, rtol=1e-5, atol=1e-5) + + +@require_rich +class TestPrintPromptCompletionsSample(unittest.TestCase): + @patch("sys.stdout", new_callable=StringIO) + def test_print_output(self, mock_stdout): + prompts = ["The sky is", "The sun is"] + completions = [" blue.", " in the sky."] + rewards = {"Correctness": [0.123, 0.456], "Format": [0.789, 0.101]} + advantages = [0.987, 0.654] + step = 42 + + print_prompt_completions_sample(prompts, completions, rewards, advantages, step) + + output = mock_stdout.getvalue() + + expected_output = textwrap.dedent("""\ + ╭──────────────────────────── Step 42 ─────────────────────────────╮ + │ ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━┓ │ + │ ┃ Prompt ┃ Completion ┃ Correctness ┃ Format ┃ Advantage ┃ │ + │ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━┩ │ + │ │ The sky is │ blue. │ 0.12 │ 0.79 │ 0.99 │ │ + │ ├────────────┼──────────────┼─────────────┼────────┼───────────┤ │ + │ │ The sun is │ in the sky. │ 0.46 │ 0.10 │ 0.65 │ │ + │ └────────────┴──────────────┴─────────────┴────────┴───────────┘ │ + ╰──────────────────────────────────────────────────────────────────╯ + """) + self.assertEqual(output, expected_output) + + @patch("sys.stdout", new_callable=StringIO) + def test_num_samples(self, mock_stdout): + prompts = ["A", "B"] + completions = ["1", "2"] + rewards = {"Score": [0.1, 0.2]} + advantages = [0.3, 0.4] + step = 10 + + print_prompt_completions_sample(prompts, completions, rewards, advantages, step, num_samples=1) + output = mock_stdout.getvalue() + + possible_outputs = [ + textwrap.dedent("""\ + ╭────────────────── Step 10 ──────────────────╮ + │ ┏━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┓ │ + │ ┃ Prompt ┃ Completion ┃ Score ┃ Advantage ┃ │ + │ ┡━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━┩ │ + │ │ A │ 1 │ 0.10 │ 0.30 │ │ + │ └────────┴────────────┴───────┴───────────┘ │ + ╰─────────────────────────────────────────────╯ + """), + textwrap.dedent("""\ + ╭────────────────── Step 10 ──────────────────╮ + │ ┏━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┓ │ + │ ┃ Prompt ┃ Completion ┃ Score ┃ Advantage ┃ │ + │ ┡━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━┩ │ + │ │ B │ 2 │ 0.20 │ 0.40 │ │ + │ └────────┴────────────┴───────┴───────────┘ │ + ╰─────────────────────────────────────────────╯ + """), + ] + self.assertIn(output, possible_outputs) diff --git a/tests/test_vllm_client_server.py b/tests/test_vllm_client_server.py new file mode 100644 index 0000000000000000000000000000000000000000..63f2ea3e230a03e11ca426487856a7bc0abbfeaa --- /dev/null +++ b/tests/test_vllm_client_server.py @@ -0,0 +1,336 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import signal +import subprocess +import unittest + +import psutil +import pytest +from transformers import AutoModelForCausalLM +from transformers.testing_utils import require_torch_multi_accelerator, torch_device + +from trl.extras.vllm_client import VLLMClient +from trl.scripts.vllm_serve import chunk_list + +from .testing_utils import require_3_accelerators + + +class TestChunkList(unittest.TestCase): + def test_even_split(self): + self.assertEqual(chunk_list([1, 2, 3, 4, 5, 6], 2), [[1, 2, 3], [4, 5, 6]]) + + def test_uneven_split(self): + self.assertEqual(chunk_list([1, 2, 3, 4, 5, 6], 4), [[1, 2], [3, 4], [5], [6]]) + + def test_more_chunks_than_elements(self): + self.assertEqual(chunk_list([1, 2, 3, 4, 5, 6], 8), [[1], [2], [3], [4], [5], [6], [], []]) + + def test_n_equals_len(self): + self.assertEqual(chunk_list([1, 2, 3], 3), [[1], [2], [3]]) + + def test_n_is_1(self): + self.assertEqual(chunk_list([1, 2, 3], 1), [[1, 2, 3]]) + + def test_single_element_list(self): + self.assertEqual(chunk_list([42], 2), [[42], []]) + + def test_any_dtype(self): + self.assertEqual( + chunk_list([1, "two", 3.0, {"four": 4}, ["f", "i", "v", "e"]], 2), + [[1, "two", 3.0], [{"four": 4}, ["f", "i", "v", "e"]]], + ) + + +@pytest.mark.slow +@require_torch_multi_accelerator +class TestVLLMClientServer(unittest.TestCase): + model_id = "Qwen/Qwen2.5-1.5B" + + @classmethod + def setUpClass(cls): + # We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1" + env = os.environ.copy() + VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES" + env[VISIBLE_DEVICES] = "1" # Restrict to accelerator 1 + + # Start the server process + cls.server_process = subprocess.Popen( + ["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env + ) + + # Initialize the client + cls.client = VLLMClient(connection_timeout=240) + cls.client.init_communicator() + + def test_generate(self): + prompts = ["Hello, AI!", "Tell me a joke"] + outputs = self.client.generate(prompts) + + # Check that the output is a list + self.assertIsInstance(outputs, list) + + # Check that the number of generated sequences is equal to the number of prompts + self.assertEqual(len(outputs), len(prompts)) + + # Check that the generated sequences are lists of integers + for seq in outputs: + self.assertTrue(all(isinstance(tok, int) for tok in seq)) + + def test_generate_with_params(self): + prompts = ["Hello, AI!", "Tell me a joke"] + outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32) + + # Check that the output is a list + self.assertIsInstance(outputs, list) + + # Check that the number of generated sequences is 2 times the number of prompts + self.assertEqual(len(outputs), 2 * len(prompts)) + + # Check that the generated sequences are lists of integers + for seq in outputs: + self.assertTrue(all(isinstance(tok, int) for tok in seq)) + + # Check that the length of the generated sequences is less than or equal to 32 + for seq in outputs: + self.assertLessEqual(len(seq), 32) + + def test_update_model_params(self): + model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device) + self.client.update_model_params(model) + + def test_reset_prefix_cache(self): + # Test resetting the prefix cache + self.client.reset_prefix_cache() + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + + # Close the client + cls.client.close_communicator() + + # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to + # kill the server process and its children explicitly. + parent = psutil.Process(cls.server_process.pid) + children = parent.children(recursive=True) + for child in children: + child.send_signal(signal.SIGTERM) + cls.server_process.terminate() + cls.server_process.wait() + + +# Same as above but using base_url to instantiate the client. +@pytest.mark.slow +@require_torch_multi_accelerator +class TestVLLMClientServerBaseURL(unittest.TestCase): + model_id = "Qwen/Qwen2.5-1.5B" + + @classmethod + def setUpClass(cls): + # We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1" + env = os.environ.copy() + VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES" + env[VISIBLE_DEVICES] = "1" # Restrict to accelerator 1 + + # Start the server process + cls.server_process = subprocess.Popen( + ["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env + ) + + # Initialize the client + cls.client = VLLMClient(base_url="http://localhost:8000", connection_timeout=240) + cls.client.init_communicator() + + def test_generate(self): + prompts = ["Hello, AI!", "Tell me a joke"] + outputs = self.client.generate(prompts) + + # Check that the output is a list + self.assertIsInstance(outputs, list) + + # Check that the number of generated sequences is equal to the number of prompts + self.assertEqual(len(outputs), len(prompts)) + + # Check that the generated sequences are lists of integers + for seq in outputs: + self.assertTrue(all(isinstance(tok, int) for tok in seq)) + + def test_generate_with_params(self): + prompts = ["Hello, AI!", "Tell me a joke"] + outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32) + + # Check that the output is a list + self.assertIsInstance(outputs, list) + + # Check that the number of generated sequences is 2 times the number of prompts + self.assertEqual(len(outputs), 2 * len(prompts)) + + # Check that the generated sequences are lists of integers + for seq in outputs: + self.assertTrue(all(isinstance(tok, int) for tok in seq)) + + # Check that the length of the generated sequences is less than or equal to 32 + for seq in outputs: + self.assertLessEqual(len(seq), 32) + + def test_update_model_params(self): + model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device) + self.client.update_model_params(model) + + def test_reset_prefix_cache(self): + # Test resetting the prefix cache + self.client.reset_prefix_cache() + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + + # Close the client + cls.client.close_communicator() + + # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to + # kill the server process and its children explicitly. + parent = psutil.Process(cls.server_process.pid) + children = parent.children(recursive=True) + for child in children: + child.send_signal(signal.SIGTERM) + cls.server_process.terminate() + cls.server_process.wait() + + +@pytest.mark.slow +@require_3_accelerators +class TestVLLMClientServerTP(unittest.TestCase): + model_id = "Qwen/Qwen2.5-1.5B" + + @classmethod + def setUpClass(cls): + # We want the server to run on accelerator 1 and 2, so we set VISIBLE_DEVICES to "1,2" + env = os.environ.copy() + VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES" + env[VISIBLE_DEVICES] = "1,2" # Restrict to accelerator 1 and 2 + + # Start the server process + cls.server_process = subprocess.Popen( + ["trl", "vllm-serve", "--model", cls.model_id, "--tensor_parallel_size", "2"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, + ) + + # Initialize the client + cls.client = VLLMClient(connection_timeout=240) + cls.client.init_communicator() + + def test_generate(self): + prompts = ["Hello, AI!", "Tell me a joke"] + outputs = self.client.generate(prompts) + + # Check that the output is a list + self.assertIsInstance(outputs, list) + + # Check that the number of generated sequences is equal to the number of prompts + self.assertEqual(len(outputs), len(prompts)) + + # Check that the generated sequences are lists of integers + for seq in outputs: + self.assertTrue(all(isinstance(tok, int) for tok in seq)) + + def test_update_model_params(self): + model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device) + self.client.update_model_params(model) + + def test_reset_prefix_cache(self): + # Test resetting the prefix cache + self.client.reset_prefix_cache() + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + + # Close the client + cls.client.close_communicator() + + # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to + # kill the server process and its children explicitly. + parent = psutil.Process(cls.server_process.pid) + children = parent.children(recursive=True) + for child in children: + child.send_signal(signal.SIGTERM) + cls.server_process.terminate() + cls.server_process.wait() + + +@pytest.mark.slow +@require_3_accelerators +class TestVLLMClientServerDP(unittest.TestCase): + model_id = "Qwen/Qwen2.5-1.5B" + + @classmethod + def setUpClass(cls): + # We want the server to run on accelerator 1 and 2, so we set VISIBLE_DEVICES to "1,2" + env = os.environ.copy() + VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES" + env[VISIBLE_DEVICES] = "1,2" # Restrict to accelerator 1 and 2 + + # Start the server process + cls.server_process = subprocess.Popen( + ["trl", "vllm-serve", "--model", cls.model_id, "--data_parallel_size", "2"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=env, + ) + + # Initialize the client + cls.client = VLLMClient(connection_timeout=240) + + def test_generate(self): + prompts = ["Hello, AI!", "Tell me a joke"] + outputs = self.client.generate(prompts) + + # Check that the output is a list + self.assertIsInstance(outputs, list) + + # Check that the number of generated sequences is equal to the number of prompts + self.assertEqual(len(outputs), len(prompts)) + + # Check that the generated sequences are lists of integers + for seq in outputs: + self.assertTrue(all(isinstance(tok, int) for tok in seq)) + + def test_update_model_params(self): + model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device) + self.client.update_model_params(model) + + def test_reset_prefix_cache(self): + # Test resetting the prefix cache + self.client.reset_prefix_cache() + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + + # Close the client + cls.client.close_communicator() + + # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to + # kill the server process and its children explicitly. + parent = psutil.Process(cls.server_process.pid) + children = parent.children(recursive=True) + for child in children: + child.send_signal(signal.SIGTERM) + cls.server_process.terminate() + cls.server_process.wait() diff --git a/tests/test_xpo_trainer.py b/tests/test_xpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..02d6fb6fd7d9d0d4b0a3a6ef4deddb76d91d0cbd --- /dev/null +++ b/tests/test_xpo_trainer.py @@ -0,0 +1,223 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest + +from datasets import load_dataset +from parameterized import parameterized +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer +from transformers.testing_utils import require_peft +from transformers.utils import is_peft_available + +from trl import XPOConfig, XPOTrainer + +from .testing_utils import RandomPairwiseJudge, require_llm_blender + + +if is_peft_available(): + from peft import LoraConfig, get_peft_model + + +class TestXPOTrainer(unittest.TestCase): + def setUp(self): + self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + self.model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.reward_model = AutoModelForSequenceClassification.from_pretrained(self.model_id, num_labels=1) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer.pad_token = self.tokenizer.eos_token + + @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) + def test_xpo_trainer_training(self, config_name): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = XPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + + trainer = XPOTrainer( + model=self.model, + ref_model=self.ref_model, + reward_model=self.reward_model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + trainer.train() + + # Check if training loss is available + self.assertIn("train_loss", trainer.state.log_history[-1]) + + @require_peft + def test_training_with_peft(self): + lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = XPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = XPOTrainer( + model=self.model, + reward_model=self.reward_model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + trainer.train() + + # Check if training loss is available + self.assertIn("train_loss", trainer.state.log_history[-1]) + + @require_peft + def test_training_with_peft_and_ref_model(self): + lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = XPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = XPOTrainer( + model=self.model, + ref_model=self.ref_model, + reward_model=self.reward_model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + trainer.train() + + # Check if training loss is available + self.assertIn("train_loss", trainer.state.log_history[-1]) + + @require_peft + def test_training_with_peft_model_and_peft_config(self): + model_lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM") + model = get_peft_model(self.model, model_lora_config) + # we want only the "train adapter" to be trained + lora_train_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = XPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = XPOTrainer( + model=model, + reward_model=self.reward_model, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_train_config, + ) + + trainer.train() + + # Check if training loss is available + self.assertIn("train_loss", trainer.state.log_history[-1]) + + @require_peft + def test_training_pre_pefted_model_implicit_ref(self): + lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM") + peft_model_instance = get_peft_model(self.model, lora_config) + + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = XPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=1, + max_steps=2, + learning_rate=5.0e-7, + eval_strategy="no", + report_to="none", + remove_unused_columns=False, + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")["train"] + + trainer = XPOTrainer( + model=peft_model_instance, + ref_model=None, + reward_model=self.reward_model, # Using reward_model to ensure _generate_completions is used as expected + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset, + ) + + trainer.train() + + self.assertIn("train_loss", trainer.state.log_history[-1]) + + @require_llm_blender + @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) + def test_xpo_trainer_judge_training(self, config_name): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = XPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + judge = RandomPairwiseJudge() + + trainer = XPOTrainer( + model=self.model, + ref_model=self.ref_model, + judge=judge, + args=training_args, + processing_class=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + trainer.train() + + # Check if training loss is available + self.assertIn("train_loss", trainer.state.log_history[-1]) diff --git a/tests/testing_constants.py b/tests/testing_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..a4c6755c5d813083855221a51e927c394be03df6 --- /dev/null +++ b/tests/testing_constants.py @@ -0,0 +1,18 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +CI_HUB_USER = "__DUMMY_TRANSFORMERS_USER__" +CI_HUB_USER_FULL_NAME = "Dummy User" + +CI_HUB_ENDPOINT = "https://hub-ci.huggingface.co" diff --git a/tests/testing_utils.py b/tests/testing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0211fd362c877f75a7151550a57da0e66b37452c --- /dev/null +++ b/tests/testing_utils.py @@ -0,0 +1,126 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import unittest + +import torch +from transformers import is_bitsandbytes_available, is_comet_available, is_sklearn_available, is_wandb_available +from transformers.testing_utils import torch_device +from transformers.utils import is_rich_available + +from trl import BaseBinaryJudge, BasePairwiseJudge +from trl.import_utils import ( + is_diffusers_available, + is_joblib_available, + is_llm_blender_available, + is_mergekit_available, + is_vllm_available, +) + + +# transformers.testing_utils contains a require_bitsandbytes function, but relies on pytest markers which we don't use +# in our test suite. We therefore need to implement our own version of this function. +def require_bitsandbytes(test_case): + """ + Decorator marking a test that requires bitsandbytes. Skips the test if bitsandbytes is not available. + """ + return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case) + + +def require_comet(test_case): + """ + Decorator marking a test that requires Comet. Skips the test if Comet is not available. + """ + return unittest.skipUnless(is_comet_available(), "test requires comet_ml")(test_case) + + +def require_diffusers(test_case): + """ + Decorator marking a test that requires diffusers. Skips the test if diffusers is not available. + """ + return unittest.skipUnless(is_diffusers_available(), "test requires diffusers")(test_case) + + +def require_llm_blender(test_case): + """ + Decorator marking a test that requires llm-blender. Skips the test if llm-blender is not available. + """ + return unittest.skipUnless(is_llm_blender_available(), "test requires llm-blender")(test_case) + + +def require_mergekit(test_case): + """ + Decorator marking a test that requires mergekit. Skips the test if mergekit is not available. + """ + return unittest.skipUnless(is_mergekit_available(), "test requires mergekit")(test_case) + + +def require_rich(test_case): + """ + Decorator marking a test that requires rich. Skips the test if rich is not available. + """ + return unittest.skipUnless(is_rich_available(), "test requires rich")(test_case) + + +def require_sklearn(test_case): + """ + Decorator marking a test that requires sklearn. Skips the test if sklearn is not available. + """ + return unittest.skipUnless(is_sklearn_available() and is_joblib_available(), "test requires sklearn")(test_case) + + +def require_vllm(test_case): + """ + Decorator marking a test that requires vllm. Skips the test if vllm is not available. + """ + return unittest.skipUnless(is_vllm_available(), "test requires vllm")(test_case) + + +def require_no_wandb(test_case): + """ + Decorator marking a test that requires no wandb. Skips the test if wandb is available. + """ + return unittest.skipUnless(not is_wandb_available(), "test requires no wandb")(test_case) + + +def require_3_accelerators(test_case): + """ + Decorator marking a test that requires at least 3 accelerators. Skips the test if 3 accelerators are not available. + """ + torch_accelerator_module = getattr(torch, torch_device, torch.cuda) + return unittest.skipUnless( + torch_accelerator_module.device_count() > 3, f"test requires at least 3 {torch_device}s" + )(test_case) + + +class RandomBinaryJudge(BaseBinaryJudge): + """ + Random binary judge, for testing purposes. + """ + + def judge(self, prompts, completions, gold_completions=None, shuffle_order=True): + return [random.choice([0, 1, -1]) for _ in range(len(prompts))] + + +class RandomPairwiseJudge(BasePairwiseJudge): + """ + Random pairwise judge, for testing purposes. + """ + + def judge(self, prompts, completions, shuffle_order=True, return_scores=False): + if not return_scores: + return [random.randint(0, len(completion) - 1) for completion in completions] + else: + return [random.random() for _ in range(len(prompts))] diff --git a/trl/__init__.py b/trl/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c793961765708a27bd836d3f3ee2fee7261ed5c8 --- /dev/null +++ b/trl/__init__.py @@ -0,0 +1,220 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__version__ = "0.19.0.dev0" + +from typing import TYPE_CHECKING + +from .import_utils import OptionalDependencyNotAvailable, _LazyModule, is_diffusers_available + + +_import_structure = { + "scripts": ["init_zero_verbose", "ScriptArguments", "TrlParser"], + "data_utils": [ + "apply_chat_template", + "extract_prompt", + "is_conversational", + "maybe_apply_chat_template", + "maybe_convert_to_chatml", + "maybe_extract_prompt", + "maybe_unpair_preference_dataset", + "pack_dataset", + "pack_examples", + "truncate_dataset", + "unpair_preference_dataset", + ], + "environment": ["TextEnvironment", "TextHistory"], + "extras": ["BestOfNSampler"], + "models": [ + "SUPPORTED_ARCHITECTURES", + "AutoModelForCausalLMWithValueHead", + "AutoModelForSeq2SeqLMWithValueHead", + "PreTrainedModelWrapper", + "create_reference_model", + "setup_chat_format", + ], + "trainer": [ + "AlignPropConfig", + "AlignPropTrainer", + "AllTrueJudge", + "BaseBinaryJudge", + "BaseJudge", + "BasePairwiseJudge", + "BaseRankJudge", + "BCOConfig", + "BCOTrainer", + "CPOConfig", + "CPOTrainer", + "DataCollatorForCompletionOnlyLM", + "DPOConfig", + "DPOTrainer", + "FDivergenceConstants", + "FDivergenceType", + "GKDConfig", + "GKDTrainer", + "GRPOConfig", + "GRPOTrainer", + "HfPairwiseJudge", + "IterativeSFTConfig", + "IterativeSFTTrainer", + "KTOConfig", + "KTOTrainer", + "LogCompletionsCallback", + "MergeModelCallback", + "ModelConfig", + "NashMDConfig", + "NashMDTrainer", + "OnlineDPOConfig", + "OnlineDPOTrainer", + "OpenAIPairwiseJudge", + "ORPOConfig", + "ORPOTrainer", + "PairRMJudge", + "PPOConfig", + "PPOTrainer", + "PRMConfig", + "PRMTrainer", + "RewardConfig", + "RewardTrainer", + "RLOOConfig", + "RLOOTrainer", + "SFTConfig", + "SFTTrainer", + "WinRateCallback", + "XPOConfig", + "XPOTrainer", + ], + "trainer.callbacks": ["MergeModelCallback", "RichProgressCallback", "SyncRefModelCallback"], + "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"], +} + +try: + if not is_diffusers_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["models"].extend( + [ + "DDPOPipelineOutput", + "DDPOSchedulerOutput", + "DDPOStableDiffusionPipeline", + "DefaultDDPOStableDiffusionPipeline", + ] + ) + _import_structure["trainer"].extend(["DDPOConfig", "DDPOTrainer"]) + +if TYPE_CHECKING: + from .data_utils import ( + apply_chat_template, + extract_prompt, + is_conversational, + maybe_apply_chat_template, + maybe_convert_to_chatml, + maybe_extract_prompt, + maybe_unpair_preference_dataset, + pack_dataset, + pack_examples, + truncate_dataset, + unpair_preference_dataset, + ) + from .environment import TextEnvironment, TextHistory + from .extras import BestOfNSampler + from .models import ( + SUPPORTED_ARCHITECTURES, + AutoModelForCausalLMWithValueHead, + AutoModelForSeq2SeqLMWithValueHead, + PreTrainedModelWrapper, + create_reference_model, + setup_chat_format, + ) + from .scripts import ScriptArguments, TrlParser, init_zero_verbose + from .trainer import ( + AlignPropConfig, + AlignPropTrainer, + AllTrueJudge, + BaseBinaryJudge, + BaseJudge, + BasePairwiseJudge, + BaseRankJudge, + BCOConfig, + BCOTrainer, + CPOConfig, + CPOTrainer, + DataCollatorForCompletionOnlyLM, + DPOConfig, + DPOTrainer, + FDivergenceConstants, + FDivergenceType, + GKDConfig, + GKDTrainer, + GRPOConfig, + GRPOTrainer, + HfPairwiseJudge, + IterativeSFTConfig, + IterativeSFTTrainer, + KTOConfig, + KTOTrainer, + LogCompletionsCallback, + MergeModelCallback, + ModelConfig, + NashMDConfig, + NashMDTrainer, + OnlineDPOConfig, + OnlineDPOTrainer, + OpenAIPairwiseJudge, + ORPOConfig, + ORPOTrainer, + PairRMJudge, + PPOConfig, + PPOTrainer, + PRMConfig, + PRMTrainer, + RewardConfig, + RewardTrainer, + RLOOConfig, + RLOOTrainer, + SFTConfig, + SFTTrainer, + WinRateCallback, + XPOConfig, + XPOTrainer, + ) + from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback + from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config + + try: + if not is_diffusers_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .models import ( + DDPOPipelineOutput, + DDPOSchedulerOutput, + DDPOStableDiffusionPipeline, + DefaultDDPOStableDiffusionPipeline, + ) + from .trainer import DDPOConfig, DDPOTrainer + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + extra_objects={"__version__": __version__}, + ) diff --git a/trl/accelerate_configs/fsdp1.yaml b/trl/accelerate_configs/fsdp1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c01b0b567bc93bf87ec136ea975b3793d273a45c --- /dev/null +++ b/trl/accelerate_configs/fsdp1.yaml @@ -0,0 +1,28 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +enable_cpu_affinity: false +fsdp_config: + fsdp_activation_checkpointing: false + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_backward_prefetch: BACKWARD_PRE + fsdp_cpu_ram_efficient_loading: true + fsdp_forward_prefetch: true + fsdp_offload_params: false + fsdp_reshard_after_forward: FULL_SHARD + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sync_module_states: true + fsdp_use_orig_params: true + fsdp_version: 1 +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/trl/accelerate_configs/fsdp2.yaml b/trl/accelerate_configs/fsdp2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..af498f3eced9c2434b80113f2f22d40395e0ab8a --- /dev/null +++ b/trl/accelerate_configs/fsdp2.yaml @@ -0,0 +1,25 @@ +# Requires accelerate 1.7.0 or higher +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +enable_cpu_affinity: false +fsdp_config: + fsdp_activation_checkpointing: false + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_cpu_ram_efficient_loading: true + fsdp_offload_params: false + fsdp_reshard_after_forward: true + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_version: 2 +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/trl/accelerate_configs/multi_gpu.yaml b/trl/accelerate_configs/multi_gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..15dad9be3ba44f7c934e1ecab98a93cb83cbc79a --- /dev/null +++ b/trl/accelerate_configs/multi_gpu.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +gpu_ids: all +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/trl/accelerate_configs/single_gpu.yaml b/trl/accelerate_configs/single_gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ebd00a067118e56f3d63ab0f24827cfea21b24b9 --- /dev/null +++ b/trl/accelerate_configs/single_gpu.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: "NO" +downcast_bf16: 'no' +gpu_ids: all +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/trl/accelerate_configs/zero1.yaml b/trl/accelerate_configs/zero1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d5b5f782fb30f9fcbcc8fc58262f09eaf2e10368 --- /dev/null +++ b/trl/accelerate_configs/zero1.yaml @@ -0,0 +1,20 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 1 + zero3_init_flag: false + zero_stage: 1 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/trl/accelerate_configs/zero2.yaml b/trl/accelerate_configs/zero2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..239b14ac3a9ae8de73122d1154bf0d71903dc15f --- /dev/null +++ b/trl/accelerate_configs/zero2.yaml @@ -0,0 +1,21 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/trl/accelerate_configs/zero3.yaml b/trl/accelerate_configs/zero3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b5a1201f8a2ee8706b63f0f80c664a1fc61a7d9d --- /dev/null +++ b/trl/accelerate_configs/zero3.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/trl/cli.py b/trl/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..9927687ce1ca3d775832d24efb1d2ac6fc22c909 --- /dev/null +++ b/trl/cli.py @@ -0,0 +1,137 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.resources as resources +import os +import sys +import warnings + +from accelerate.commands.launch import launch_command, launch_command_parser + +from .scripts.dpo import make_parser as make_dpo_parser +from .scripts.env import print_env +from .scripts.grpo import make_parser as make_grpo_parser +from .scripts.kto import make_parser as make_kto_parser +from .scripts.sft import make_parser as make_sft_parser +from .scripts.utils import TrlParser +from .scripts.vllm_serve import main as vllm_serve_main +from .scripts.vllm_serve import make_parser as make_vllm_serve_parser + + +def main(): + parser = TrlParser(prog="TRL CLI", usage="trl", allow_abbrev=False) + + # Add the subparsers + subparsers = parser.add_subparsers(help="available commands", dest="command", parser_class=TrlParser) + + # Add the subparsers for every script + make_dpo_parser(subparsers) + subparsers.add_parser("env", help="Print the environment information") + make_grpo_parser(subparsers) + make_kto_parser(subparsers) + make_sft_parser(subparsers) + make_vllm_serve_parser(subparsers) + + # Parse the arguments; the remaining ones (`launch_args`) are passed to the 'accelerate launch' subparser. + # Duplicates may occur if the same argument is provided in both the config file and CLI. + # For example: launch_args = `["--num_processes", "4", "--num_processes", "8"]`. + # Deduplication and precedence (CLI over config) are handled later by launch_command_parser. + args, launch_args = parser.parse_args_and_config(return_remaining_strings=True) + + # Replace `--accelerate_config foo` with `--config_file trl/accelerate_configs/foo.yaml` if it is present in the + # launch_args. It allows the user to use predefined accelerate configs from the `trl` package. + if "--accelerate_config" in launch_args: + # Get the index of the '--accelerate_config' argument and the corresponding config name + config_index = launch_args.index("--accelerate_config") + config_name = launch_args[config_index + 1] + + # If the config_name correspond to a path in the filesystem, we don't want to override it + if os.path.isfile(config_name): + accelerate_config_path = config_name + elif resources.files("trl.accelerate_configs").joinpath(f"{config_name}.yaml").exists(): + # Get the predefined accelerate config path from the package resources + accelerate_config_path = resources.files("trl.accelerate_configs").joinpath(f"{config_name}.yaml") + else: + raise ValueError( + f"Accelerate config {config_name} is neither a file nor a valid config in the `trl` package. " + "Please provide a valid config name or a path to a config file." + ) + + # Remove '--accelerate_config' and its corresponding config name + launch_args.pop(config_index) + launch_args.pop(config_index) + + # Insert '--config_file' and the absolute path to the front of the list + launch_args = ["--config_file", str(accelerate_config_path)] + launch_args + + if args.command == "dpo": + # Get the default args for the launch command + dpo_training_script = resources.files("trl.scripts").joinpath("dpo.py") + args = launch_command_parser().parse_args([str(dpo_training_script)]) + + # Feed the args to the launch command + args.training_script_args = sys.argv[2:] # remove "trl" and "dpo" + launch_command(args) # launch training + + elif args.command == "env": + print_env() + + elif args.command == "grpo": + # Get the default args for the launch command + grpo_training_script = resources.files("trl.scripts").joinpath("grpo.py") + args = launch_command_parser().parse_args([str(grpo_training_script)]) + + # Feed the args to the launch command + args.training_script_args = sys.argv[2:] # remove "trl" and "grpo" + launch_command(args) # launch training + + elif args.command == "kto": + # Get the default args for the launch command + kto_training_script = resources.files("trl.scripts").joinpath("kto.py") + args = launch_command_parser().parse_args([str(kto_training_script)]) + + # Feed the args to the launch command + args.training_script_args = sys.argv[2:] # remove "trl" and "kto" + launch_command(args) # launch training + + elif args.command == "sft": + # Get the path to the training script + sft_training_script = resources.files("trl.scripts").joinpath("sft.py") + + # This simulates running: `accelerate launch sft.py `. + # Note that the training script args may include launch-related arguments (e.g., `--num_processes`), + # but we rely on the script to ignore any that don't apply to it. + training_script_args = sys.argv[2:] # Remove "trl" and "sft" + args = launch_command_parser().parse_args(launch_args + [str(sft_training_script)] + training_script_args) + launch_command(args) # launch training + + elif args.command == "vllm-serve": + (script_args,) = parser.parse_args_and_config() + + # Known issue: Using DeepSpeed with tensor_parallel_size=1 and data_parallel_size>1 may cause a crash when + # launched via the CLI. Suggest running the module directly. + # More information: https://github.com/vllm-project/vllm/issues/17079 + if script_args.tensor_parallel_size == 1 and script_args.data_parallel_size > 1: + warnings.warn( + "Detected configuration: tensor_parallel_size=1 and data_parallel_size>1. This setup is known to " + "cause a crash when using the `trl vllm-serve` CLI entry point. As a workaround, please run the " + "server using the module path instead: `python -m trl.scripts.vllm_serve`", + RuntimeWarning, + ) + + vllm_serve_main(script_args) + + +if __name__ == "__main__": + main() diff --git a/trl/core.py b/trl/core.py new file mode 100644 index 0000000000000000000000000000000000000000..bd8733bedeb313fea48aecd0c67aa7d89d5132d4 --- /dev/null +++ b/trl/core.py @@ -0,0 +1,159 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import warnings +from collections.abc import Mapping +from contextlib import contextmanager +from typing import Optional, Union + +import numpy as np +import torch +from transformers import is_torch_npu_available, is_torch_xpu_available + + +def flatten_dict(nested: dict, sep: str = "/") -> dict: + """Flatten dictionary and concatenate nested keys with separator.""" + + def recurse(nest: dict, prefix: str, into: dict) -> None: + for k, v in nest.items(): + if sep in k: + raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'") + if isinstance(v, Mapping): + recurse(v, prefix + k + sep, into) + else: + into[prefix + k] = v + + flat = {} + recurse(nested, "", flat) + return flat + + +def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor: + """Compute mean of tensor with a masked values.""" + if axis is not None: + return (values * mask).sum(axis=axis) / mask.sum(axis=axis) + else: + return (values * mask).sum() / mask.sum() + + +def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor: + """Compute variance of tensor with masked values.""" + mean = masked_mean(values, mask) + centered_values = values - mean + variance = masked_mean(centered_values**2, mask) + if unbiased: + mask_sum = mask.sum() + if mask_sum == 0: + raise ValueError( + "The sum of the mask is zero, which can happen when `mini_batch_size=1`;" + "try increase the `mini_batch_size` or `gradient_accumulation_steps`" + ) + # note that if mask_sum == 1, then there is a division by zero issue + # to avoid it you just need to use a larger minibatch_size + bessel_correction = mask_sum / (mask_sum - 1) + variance = variance * bessel_correction + return variance + + +def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor: + """Whiten values with masked values.""" + mean, var = masked_mean(values, mask), masked_var(values, mask) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +class LengthSampler: + """ + Samples a length + """ + + def __init__(self, min_value: int, max_value: int): + self.values = list(range(min_value, max_value)) + + def __call__(self) -> int: + return np.random.choice(self.values) + + +class PPODecorators: + optimize_device_cache = False + + @classmethod + @contextmanager + def empty_device_cache(cls): + yield + if cls.optimize_device_cache: + if is_torch_xpu_available(): + gc.collect() + torch.xpu.empty_cache() + gc.collect() + elif is_torch_npu_available(): + gc.collect() + torch.npu.empty_cache() + gc.collect() + elif torch.cuda.is_available(): + gc.collect() + torch.cuda.empty_cache() + gc.collect() + + +def randn_tensor( + shape: Union[tuple, list], + generator: Optional[Union[list[torch.Generator], torch.Generator]] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + layout: Optional[torch.layout] = None, +) -> torch.Tensor: + """A helper function to create random tensors on the desired `device` with the desired `dtype`. When + passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor + is always created on the CPU. + """ + # device on which tensor is created defaults to device + rand_device = device + batch_size = shape[0] + + layout = layout or torch.strided + device = device or torch.device("cpu") + + if generator is not None: + gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type + if gen_device_type != device.type and gen_device_type == "cpu": + rand_device = "cpu" + if device != "mps": + warnings.warn( + f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." + f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" + f" slighly speed up this function by passing a generator that was created on the {device} device.", + UserWarning, + ) + elif gen_device_type != device.type and gen_device_type == "cuda": + raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") + + # make sure generator list of length 1 is treated like a non-list + if isinstance(generator, list) and len(generator) == 1: + generator = generator[0] + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) + + return latents diff --git a/trl/data_utils.py b/trl/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fe286141b6a82b6648b29e82ee037a9240549b3e --- /dev/null +++ b/trl/data_utils.py @@ -0,0 +1,765 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from collections import defaultdict, deque +from collections.abc import Sequence +from itertools import takewhile +from typing import Any, Callable, Optional, TypeVar, Union + +import numpy as np +import pyarrow as pa +import pyarrow.compute as pc +import pyarrow.types +from datasets import Dataset, DatasetDict +from transformers import PreTrainedTokenizerBase + + +DatasetType = TypeVar("DatasetType", Dataset, DatasetDict) + + +def is_conversational(example: dict[str, Any]) -> bool: + r""" + Check if the example is in a conversational format. + + Args: + example (`dict[str, Any]`): + A single data entry of a dataset. The example can have different keys depending on the + dataset type. + + Returns: + `bool`: + `True` if the data is in a conversational format, `False` otherwise. + + Examples: + + ```python + >>> example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]} + >>> is_conversational(example) + True + >>> example = {"prompt": "The sky is"}) + >>> is_conversational(example) + False + ``` + """ + supported_keys = ["prompt", "chosen", "rejected", "completion", "messages"] + example_keys = {key for key in example.keys() if key in supported_keys} + + # It must have one of the supported keys + if example_keys: + key = example_keys.pop() # take the first supported key + maybe_messages = example[key] + # It must be a list of messages, + if isinstance(maybe_messages, list): + maybe_message = maybe_messages[0] + # Each message must a list of dictionaries with keys "role" and "content" + if isinstance(maybe_message, dict) and "role" in maybe_message and "content" in maybe_message: + return True + + return False + + +def apply_chat_template( + example: dict[str, list[dict[str, str]]], + tokenizer: PreTrainedTokenizerBase, + tools: Optional[list[Union[dict, Callable]]] = None, +) -> dict[str, str]: + r""" + Apply a chat template to a conversational example along with the schema for a list of functions in `tools`. + + For more details, see [`maybe_apply_chat_template`]. + """ + # Check that the example has the correct keys + supported_keys = ["prompt", "chosen", "rejected", "completion", "messages", "label"] + example_keys = {key for key in example.keys() if key in supported_keys} + if example_keys not in [ + {"messages"}, # language modeling + {"prompt"}, # prompt-only + {"prompt", "completion"}, # prompt-completion + {"prompt", "chosen", "rejected"}, # preference + {"chosen", "rejected"}, # preference with implicit prompt + {"prompt", "completion", "label"}, # unpaired preference + ]: + raise KeyError(f"Invalid keys in the example: {example_keys}") + + # Apply the chat template to the whole conversation + if "messages" in example: + messages = tokenizer.apply_chat_template(example["messages"], tools=tools, tokenize=False) + + # Apply the chat template to the prompt, adding the generation prompt + if "prompt" in example: + last_role = example["prompt"][-1]["role"] + if last_role == "user": + add_generation_prompt = True + continue_final_message = False + elif last_role == "assistant": + add_generation_prompt = False + continue_final_message = True + else: + raise ValueError(f"Invalid role in the last message: {last_role}") + prompt = tokenizer.apply_chat_template( + example["prompt"], + tools=tools, + continue_final_message=continue_final_message, + tokenize=False, + add_generation_prompt=add_generation_prompt, + ) + + # Apply the chat template to the entire prompt + completion + if "prompt" in example: # explicit prompt and prompt-completion case + if "chosen" in example: + prompt_chosen = tokenizer.apply_chat_template( + example["prompt"] + example["chosen"], tools=tools, tokenize=False + ) + # DeepSeek-R1 inserts a token when using `add_generation_prompt`, which can cause discrepancies + # between the prompt alone and the combined prompt+completion. To ensure consistency, we extract the + # common prefix between the two. In most cases, this is a no-op. + prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_chosen))) + + chosen = prompt_chosen[len(prompt) :] + if "rejected" in example and "prompt" in example: # explicit prompt + prompt_rejected = tokenizer.apply_chat_template( + example["prompt"] + example["rejected"], tools=tools, tokenize=False + ) + # Handle DeepSeek-R1 token, see the above comment for details + prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_rejected))) + rejected = prompt_rejected[len(prompt) :] + if "completion" in example: + prompt_completion = tokenizer.apply_chat_template( + example["prompt"] + example["completion"], tools=tools, tokenize=False + ) + # Handle DeepSeek-R1 token, see the above comment for details + prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_completion))) + completion = prompt_completion[len(prompt) :] + else: # implicit prompt case + if "chosen" in example: + chosen = tokenizer.apply_chat_template(example["chosen"], tools=tools, tokenize=False) + if "rejected" in example: + rejected = tokenizer.apply_chat_template(example["rejected"], tools=tools, tokenize=False) + + # Extract the completion by removing the prompt part from the prompt-completion string + output = {} + if "messages" in example: + output["text"] = messages + if "prompt" in example: + output["prompt"] = prompt + if "chosen" in example: + output["chosen"] = chosen + if "rejected" in example: + output["rejected"] = rejected + if "completion" in example: + output["completion"] = completion + if "label" in example: + output["label"] = example["label"] + + return output + + +def maybe_apply_chat_template( + example: dict[str, list[dict[str, str]]], + tokenizer: PreTrainedTokenizerBase, + tools: Optional[list[Union[dict, Callable]]] = None, +) -> dict[str, str]: + r""" + If the example is in a conversational format, apply a chat template to it. + + Args: + example (`dict[str, list[dict[str, str]]`): + Dictionary representing a single data entry of a conversational dataset. Each data entry can have different + keys depending on the dataset type. The supported dataset types are: + + - Language modeling dataset: `"messages"`. + - Prompt-only dataset: `"prompt"`. + - Prompt-completion dataset: `"prompt"` and `"completion"`. + - Preference dataset: `"prompt"`, `"chosen"`, and `"rejected"`. + - Preference dataset with implicit prompt: `"chosen"` and `"rejected"`. + - Unpaired preference dataset: `"prompt"`, `"completion"`, and `"label"`. + + For keys `"messages"`, `"prompt"`, `"chosen"`, `"rejected"`, and `"completion"`, the values are lists of + messages, where each message is a dictionary with keys `"role"` and `"content"`. + tokenizer (`PreTrainedTokenizerBase`): + Tokenizer to apply the chat template with. + tools (`list[Union[dict, Callable]]` or `None`, *optional*, defaults to `None`): + A list of tools (callable functions) that will be accessible to the model. + If the template does not support function calling, this argument will have no effect + + Returns: + `dict[str, str]`: + Formatted example with the chat template applied. + + Notes: + - This function does not alter the keys, except for Language modeling dataset, where `"messages"` is replaced + by `"text"`. + + - In case of prompt-only data, if the last role is `"user"`, the generation prompt is added to the prompt. + Else, if the last role is `"assistant"`, the final message is continued. + + Example: + + ```python + >>> from transformers import AutoTokenizer + >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct") + >>> example = { + ... "prompt": [{"role": "user", "content": "What color is the sky?"}], + ... "completion": [{"role": "assistant", "content": "It is blue."}] + ... } + >>> apply_chat_template(example, tokenizer) + {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n', 'completion': 'It is blue.<|end|>\n<|endoftext|>'} + ``` + """ + if is_conversational(example): + return apply_chat_template(example, tokenizer, tools) + else: + return example + + +def _unpair_row(examples: list[dict[str, list[dict[str, str]]]]) -> list[dict[str, list[dict[str, str]]]]: + batch_size = len(examples["chosen"]) + new_rows = { + "completion": examples["chosen"] + examples["rejected"], + "label": [True] * batch_size + [False] * batch_size, + } + if "prompt" in examples: + new_rows["prompt"] = examples["prompt"] + examples["prompt"] + return new_rows + + +def unpair_preference_dataset( + dataset: DatasetType, num_proc: Optional[int] = None, desc: Optional[str] = None +) -> DatasetType: + r""" + Unpair a preference dataset. + + Args: + dataset (`Dataset` or `DatasetDict`): + Preference dataset to unpair. The dataset must have columns `"chosen"`, `"rejected"` and optionally + `"prompt"`. + num_proc (`int` or `None`, *optional*, defaults to `None`): + Number of processes to use for processing the dataset. + desc (`str` or `None`, *optional*, defaults to `None`): + Meaningful description to be displayed alongside with the progress bar while mapping examples. + + Returns: + `Dataset`: The unpaired preference dataset. + + Example: + + ```python + >>> from datasets import Dataset + >>> dataset_dict = { + ... "prompt": ["The sky is", "The sun is"] + ... "chosen": [" blue.", "in the sky."], + ... "rejected": [" green.", " in the sea."] + ... } + >>> dataset = Dataset.from_dict(dataset_dict) + >>> dataset = unpair_preference_dataset(dataset) + >>> dataset + Dataset({ + features: ['prompt', 'completion', 'label'], + num_rows: 4 + }) + >>> dataset[0] + {'prompt': 'The sky is', 'completion': ' blue.', 'label': True} + ``` + """ + return dataset.map(_unpair_row, batched=True, remove_columns=["chosen", "rejected"], num_proc=num_proc, desc=desc) + + +def maybe_unpair_preference_dataset( + dataset: DatasetType, num_proc: Optional[int] = None, desc: Optional[str] = None +) -> DatasetType: + r""" + Unpair a preference dataset if it is paired. + + Args: + dataset (`Dataset` or `DatasetDict`): + Preference dataset to unpair. The dataset must have columns `"chosen"`, `"rejected"` and optionally + `"prompt"`. + num_proc (`int` or `None`, *optional*, defaults to `None`): + Number of processes to use for processing the dataset. + desc (`str` or `None`, *optional*, defaults to `None`): + Meaningful description to be displayed alongside with the progress bar while mapping examples. + + Returns: + `Dataset` or `DatasetDict`: The unpaired preference dataset if it was paired, otherwise the original dataset. + + Example: + + ```python + >>> from datasets import Dataset + >>> dataset_dict = { + ... "prompt": ["The sky is", "The sun is"] + ... "chosen": [" blue.", "in the sky."], + ... "rejected": [" green.", " in the sea."] + ... } + >>> dataset = Dataset.from_dict(dataset_dict) + >>> dataset = unpair_preference_dataset(dataset) + >>> dataset + Dataset({ + features: ['prompt', 'completion', 'label'], + num_rows: 4 + }) + >>> dataset[0] + {'prompt': 'The sky is', 'completion': ' blue.', 'label': True} + ``` + """ + if isinstance(dataset, DatasetDict): + column_names = dataset[list(dataset.keys())[0]].column_names + else: + column_names = dataset.column_names + if "chosen" in column_names and "rejected" in column_names: + return unpair_preference_dataset(dataset, num_proc=num_proc, desc=desc) + else: + return dataset + + +def extract_prompt(example: dict[str, Sequence]) -> dict[str, Sequence]: + r""" + Extracts the shared prompt from a preference data example, where the prompt is implicit within both + the chosen and rejected completions. + + For more details, see [`maybe_extract_prompt`]. + """ + for idx in range(min(len(example["chosen"]), len(example["rejected"]))): + if example["chosen"][idx] != example["rejected"][idx]: + if example["chosen"][idx - 1] == " ": # remove space before the prompt + idx -= 1 + break + return { + "prompt": example["chosen"][:idx], + "chosen": example["chosen"][idx:], + "rejected": example["rejected"][idx:], + } + + +def maybe_extract_prompt(example: dict[str, list]) -> dict[str, list]: + r""" + Extracts the shared prompt from a preference data example, where the prompt is implicit within both + the chosen and rejected completions. + + If the example already contains a `"prompt"` key, the function returns the example as is. Else, the function + identifies the longest common sequence (prefix) of conversation turns between the "chosen" and "rejected" + completions and extracts this as the prompt. It then removes this prompt from the respective "chosen" and + "rejected" completions. + + Args: + example (`dict[str, list]`): + A dictionary representing a single data entry in the preference dataset. It must contain the keys + `"chosen"` and `"rejected"`, where each value is either conversational or standard (`str`). + + Returns: + `dict[str, list]`: A dictionary containing: + - `"prompt"`: The longest common prefix between the "chosen" and "rejected" completions. + - `"chosen"`: The remainder of the "chosen" completion, with the prompt removed. + - `"rejected"`: The remainder of the "rejected" completion, with the prompt removed. + + Examples: + + ```python + >>> example = { + ... "chosen": [ + ... {"role": "user", "content": "What color is the sky?"}, + ... {"role": "assistant", "content": "It is blue."} + ... ], + ... "rejected": [ + ... {"role": "user", "content": "What color is the sky?"}, + ... {"role": "assistant", "content": "It is green."} + ... ] + ... } + >>> extract_prompt(example) + {'prompt': [{'role': 'user', 'content': 'What color is the sky?'}], + 'chosen': [{'role': 'assistant', 'content': 'It is blue.'}], + 'rejected': [{'role': 'assistant', 'content': 'It is green.'}]} + ``` + + Or, with the `map` method of `datasets.Dataset`: + + ```python + >>> from trl import extract_prompt + >>> from datasets import Dataset + >>> dataset_dict = { + ... "chosen": [ + ... [ + ... {"role": "user", "content": "What color is the sky?"}, + ... {"role": "assistant", "content": "It is blue."}, + ... ], + ... [ + ... {"role": "user", "content": "Where is the sun?"}, + ... {"role": "assistant", "content": "In the sky."}, + ... ], + ... ], + ... "rejected": [ + ... [ + ... {"role": "user", "content": "What color is the sky?"}, + ... {"role": "assistant", "content": "It is green."}, + ... ], + ... [ + ... {"role": "user", "content": "Where is the sun?"}, + ... {"role": "assistant", "content": "In the sea."}, + ... ], + ... ], + ... } + >>> dataset = Dataset.from_dict(dataset_dict) + >>> dataset = dataset.map(extract_prompt) + >>> dataset[0] + {'prompt': [{'role': 'user', 'content': 'What color is the sky?'}], + 'chosen': [{'role': 'assistant', 'content': 'It is blue.'}], + 'rejected': [{'role': 'assistant', 'content': 'It is green.'}]} + ``` + """ + # Some dataset add a `"prompt"` column, even though the prompt is implicit and included in the "chosen" and + # "rejected" completions. E.g.: + # {"prompt": "What color is the sky?", + # "chosen": [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}], + # "rejected": [{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}]} + # That's why we check if the prompt is also conversational before deciding not to extract it. + if "chosen" not in example or "rejected" not in example: # not a preference example + return example + if "prompt" in example: + # Both conversational or both non-conversational + chosen_conv = is_conversational({"chosen": example["chosen"]}) + prompt_conv = is_conversational({"prompt": example["prompt"]}) + if (chosen_conv and prompt_conv) or (not chosen_conv and not prompt_conv): + return example + return extract_prompt({"chosen": example["chosen"], "rejected": example["rejected"]}) + + +def pack_examples(examples: dict[str, list[list]], seq_length: int) -> dict[str, list[list]]: + """ + Pack examples into chunks of size `seq_length`. + + Args: + examples (`dict[str, list[list]]`): + Dictionary of examples with keys as strings and values as lists of lists. + seq_length (`int`): + Maximum sequence length. + + Returns: + `dict[str, list[list]]`: Dictionary of examples with keys as strings and values as lists of lists. + + Example: + + ```python + >>> from trl import pack_examples + >>> examples = { + ... "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]], + ... "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]], + ... } + >>> pack_examples(examples, seq_length=5) + {'input_ids': [[1, 2, 3, 4, 5], [6, 7, 8]], 'attention_mask': [[0, 1, 1, 0, 0], [1, 1, 1]]} + >>> pack_examples(examples, seq_length=2) + {'input_ids': [[1, 2], [3, 4], [5, 6], [7, 8]], 'attention_mask': [[0, 1], [1, 0], [0, 1], [1, 1]]} + ``` + """ + warnings.warn( + "`pack_examples` is deprecated and will be removed in version 0.20.0. Use `pack_dataset` with a dataset " + "instead.", + DeprecationWarning, + ) + # Join all the values into a single list + examples = {k: sum(v, []) for k, v in examples.items()} + # Split the values into chunks of size seq_length + examples = {k: [v[i : i + seq_length] for i in range(0, len(v), seq_length)] for k, v in examples.items()} + return examples + + +class _SegmentTree: + """ + A segment tree data structure that, when initialized as `_SegmentTree(maxval)`, efficiently finds the next larger value + for a given input within the range [1, maxval]. + + See [Fewer Truncations Improve Language Modeling](https://arxiv.org/abs/2404.10830) for more details. + """ + + def __init__(self, maxval: int): + self.maxval = maxval + self.tree = [0] * (2 * maxval) + + def add(self, val): + assert 0 < val <= self.maxval + i = self.maxval + val - 1 + self.tree[i] = val + while i > 1: + i >>= 1 + left, right = self.tree[i << 1], self.tree[(i << 1) + 1] + # Compare the values using if-else otherwise repeated calls to `builtins.max` become the bottleneck + self.tree[i] = left if left >= right else right + + def remove(self, val): + assert 0 < val <= self.maxval + i = self.maxval + val - 1 + self.tree[i] = 0 + while i > 1: + i >>= 1 + left, right = self.tree[i << 1], self.tree[(i << 1) + 1] + # Compare the values using if-else otherwise repeated calls to `builtins.max` become the bottleneck + self.tree[i] = left if left >= right else right + + def search(self, val): + assert 0 < val <= self.maxval + i = 1 + while i < self.maxval: + if self.tree[i << 1] >= val: + i = i << 1 + else: + i = (i << 1) + 1 + return self.tree[i] + + +def _pack_ffd(examples: pa.Table, seq_length: int) -> pa.Table: + """Pack sequences in a pyarrow Table using First Fit Decreasing strategy.""" + # Add position_ids to the examples + input_ids = examples["input_ids"] + position_ids_python = [list(range(len(sequence))) for sequence in input_ids.to_pylist()] + position_ids_array = pa.array(position_ids_python, type=examples["input_ids"].type) + examples = examples.append_column("position_ids", position_ids_array) + + columns = [] + list_column_idx = None + for idx, column in enumerate(examples.columns): + if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type): + column = pc.list_slice(column, 0, seq_length) + if list_column_idx is None: + list_column_idx = idx + columns.append(column) + examples = pa.Table.from_arrays(columns, names=examples.column_names) + + ids = np.arange(len(examples)) + assert list_column_idx is not None + lengths = pc.make_struct(pc.list_value_length(examples[list_column_idx]).combine_chunks(), ids) + lengths = lengths.sort("descending", by=0) + + segment_tree = _SegmentTree(seq_length) + segment_tree.add(seq_length) # the max, `seq_length` bin is always available + space_to_bin = defaultdict(deque) + + # Bin is represented as a dict (of example ids and sum of their lengths) to allow in-place updates + bins: list[dict] = [] + for length, idx in zip(lengths.field(0).to_numpy(), lengths.field(1).to_numpy()): + space = segment_tree.search(length) + + if space < seq_length: + bin = space_to_bin[space].popleft() + else: + bin = {"ids": [], "length": 0} + bins.append(bin) + + bin["ids"].append(idx) + bin["length"] += length + if space < seq_length and not space_to_bin[space]: + segment_tree.remove(space) + + space = space - length + space_to_bin[space].append(bin) + if space > 0: + segment_tree.add(space) + + examples = pc.take(examples, [id_ for bin in bins for id_ in bin["ids"]]) + offsets = np.array([0] + [bin["length"] for bin in bins]) + offsets = np.cumsum(offsets) + + columns = [] + for column in examples.columns: + assert len(column.chunks) == 1 # `pc.take` returns a ChunkedArray with a single chunk + column = column.chunks[0] + if pa.types.is_list(column.type) or pa.types.is_large_list(column.type): + dtype = column.offsets.type.to_pandas_dtype() + column = type(column).from_arrays(offsets.astype(dtype), column.values) + columns.append(column) + return pa.Table.from_arrays(columns, names=examples.column_names) + + +def _pack_wrapped(examples: pa.Table, seq_length: int) -> pa.Table: + """Pack sequences in a pyarrow Table using a wrapped strategy.""" + columns = [] + for column in examples.columns: + if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type): + if isinstance(column, pa.ChunkedArray): + column = column.combine_chunks() + offsets, values = column.offsets, column.values + values = values[offsets[0].as_py() : offsets[-1].as_py()] + num_elements = len(values) + dtype = offsets.type.to_pandas_dtype() # np.int32 or np.int64 + offsets = np.arange(0, num_elements, seq_length, dtype=dtype) + offsets = np.concatenate((offsets, [num_elements])) + column = type(column).from_arrays(offsets, values) + columns.append(column) + return pa.Table.from_arrays(columns, names=examples.column_names) + + +def pack_dataset( + dataset: DatasetType, seq_length: int, strategy: str = "ffd", map_kwargs: Optional[dict[str, Any]] = None +) -> DatasetType: + r""" + Pack sequences in a dataset into chunks of size `seq_length`. + + Args: + dataset (`Dataset` or `DatasetDict`): + Dataset to pack + seq_length (`int`): + Target sequence length to pack to. + strategy (`str`, *optional*, defaults to `"ffd"`): + Packing strategy to use. Can be either: + + - `"ffd"` (First Fit Decreasing): Slower but preserves sequence boundaries. Sequences are never cut in the + middle. + - `"wrapped"`: Faster but more aggressive. Ignores sequence boundaries and will cut sequences in the middle + to completely fill each packed sequence with data. + map_kwargs (`dict` or `None`, *optional*, defaults to `None`): + Additional keyword arguments to pass to the dataset's map method when packing examples. + + Returns: + `Dataset` or `DatasetDict`: The dataset with packed sequences. The number of examples may + decrease as sequences are combined. + + Example: + ```python + >>> from datasets import Dataset + >>> from trl import pack_dataset + >>> examples = { + ... "input_ids": [[1, 2, 3], [4, 5], [6, 7, 8], [9]], + ... "attention_mask": [[1, 1, 0], [1, 0], [1, 0, 0], [1]] + ... } + >>> dataset = Dataset.from_dict(examples) + >>> packed_dataset = pack_dataset(dataset, seq_length=4, strategy="ffd") + >>> packed_dataset[:] + {'input_ids': [[1, 2, 3, 9], [6, 7, 8, 4, 5]], + 'attention_mask': [[1, 1, 0, 1], [1, 0, 0, 1, 0]]} + ``` + """ + if map_kwargs is None: + map_kwargs = {} + # Fast packing with pyarrow + dataset = dataset.with_format("arrow") + if strategy == "ffd": + dataset = dataset.map(_pack_ffd, batched=True, fn_kwargs={"seq_length": seq_length}, **map_kwargs) + elif strategy == "wrapped": + dataset = dataset.map(_pack_wrapped, batched=True, fn_kwargs={"seq_length": seq_length}, **map_kwargs) + else: + raise ValueError(f"Invalid packing strategy: {strategy}. Use 'ffd' or 'wrapped'.") + dataset = dataset.with_format(None) + return dataset + + +def truncate_dataset( + dataset: DatasetType, max_length: int, map_kwargs: Optional[dict[str, Any]] = None +) -> DatasetType: + r""" + Truncate sequences in a dataset to a specifed `max_length`. + + Args: + dataset (`Dataset` or `DatasetDict`): + Dataset to truncate. + seq_length (`int`): + Maximum sequence length to truncate to. + map_kwargs (`dict` or `None`, *optional*, defaults to `None`): + Additional keyword arguments to pass to the dataset's map method when truncating examples. + + Returns: + `Dataset` or `DatasetDict`: The dataset with truncated sequences. + + Example: + ```python + >>> from datasets import Dataset + >>> examples = { + ... "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]], + ... "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]], + ... } + >>> dataset = Dataset.from_dict(examples) + >>> truncated_dataset = truncate_dataset(dataset, max_length=2) + >>> truncated_dataset[:] + {'input_ids': [[1, 2], [4, 5], [8]], + 'attention_mask': [[0, 1], [0, 0], [1]]} + ``` + """ + if map_kwargs is None: + map_kwargs = {} + if isinstance(dataset, Dataset): + # Fast truncation with pyarrow + def truncate(examples): + truncated_columns = [] + for column in examples.columns: + if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type): + column = pc.list_slice(column, 0, max_length) + truncated_columns.append(column) + return pa.Table.from_arrays(truncated_columns, names=examples.column_names) + + dataset = dataset.with_format("arrow") + dataset = dataset.map(truncate, batched=True, **map_kwargs) + dataset = dataset.with_format(None) + else: + + def truncate(examples): + truncated_examples = {} + for key, column in examples.items(): + if column and isinstance(column[0], list): + column = [val[:max_length] for val in column] + truncated_examples[key] = column + return truncated_examples + + dataset = dataset.map( + truncate, + batched=True, + **map_kwargs, + ) + return dataset + + +def maybe_convert_to_chatml(example: dict[str, list]) -> dict[str, list]: + """ + Convert a conversational dataset with fields `from` and `value` to ChatML format. + + This function modifies conversational data to align with OpenAI's ChatML format: + - Replaces the key `"from"` with `"role"` in message dictionaries. + - Replaces the key `"value"` with `"content"` in message dictionaries. + - Renames `"conversations"` to `"messages"` for consistency with ChatML. + + Args: + example (`dict[str, list]`): + A single data entry containing a list of messages. + + Returns: + `dict[str, list]`: + Example reformatted to ChatML style. + + Example: + ```python + >>> from trl import maybe_convert_to_chatml + >>> example = { + ... "conversations": [ + ... {"from": "user", "value": "What color is the sky?"}, + ... {"from": "assistant", "value": "It is blue."} + ... ] + ... } + >>> maybe_convert_to_chatml(example) + {'messages': [{'role': 'user', 'content': 'What color is the sky?'}, + {'role': 'assistant', 'content': 'It is blue.'}]} + ``` + """ + # List of possible keys containing message lists + for key in ["prompt", "completion", "chosen", "rejected", "messages", "conversations"]: + if key in example and isinstance(example[key], list): + messages = example[key] + for message in messages: + if isinstance(message, dict): + if "from" in message: + message["role"] = message.pop("from") + if "value" in message: + message["content"] = message.pop("value") + + # Rename "conversations" to "messages" + if "conversations" in example: + example["messages"] = example.pop("conversations") + + return example diff --git a/trl/environment/__init__.py b/trl/environment/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..70ab98f67c6db7e8b9cff63d60ad0ca1e84587c0 --- /dev/null +++ b/trl/environment/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ..import_utils import _LazyModule + + +_import_structure = { + "base_environment": ["TextEnvironment", "TextHistory"], +} + +if TYPE_CHECKING: + from .base_environment import TextEnvironment, TextHistory +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py new file mode 100644 index 0000000000000000000000000000000000000000..5207b407387dd8c7a543d83fc58322d4a4a6aa8a --- /dev/null +++ b/trl/environment/base_environment.py @@ -0,0 +1,482 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import warnings +from typing import Optional + +import torch +from accelerate.utils import extract_model_from_parallel +from transformers import StoppingCriteria, StoppingCriteriaList +from transformers.utils import is_rich_available + + +if is_rich_available(): + from rich import print + from rich.text import Text + + +class StringStoppingCriteria(StoppingCriteria): + """Custom `StoppingCriteria` which checks if all generations in the batch are completed.""" + + def __init__(self, stop_strings, tokenizer): + self.stop_strings = stop_strings + self.tokenizer = tokenizer + self.first_call = True + + def __call__(self, input_ids, scores, **kwargs): + """Returns true if all generated sequences contain any of the stop strings.""" + if self.first_call: + self.generated_tokens = [1 for _ in range(input_ids.shape[0])] + self.start_length = input_ids.shape[-1] - 1 + self.first_call = False + decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :]) + done = [] + + for i, decoded_generation in enumerate(decoded_generations): + sequence_complete = any(stop_string in decoded_generation for stop_string in self.stop_strings) + done.append(sequence_complete) + if not sequence_complete: + self.generated_tokens[i] += 1 + + if all(done): + self.first_call = True + + return all(done) + + +class TextHistory: + """The TextHistory class keeps track of the history of an interaction between the language model and the environment.""" + + def __init__(self, text, tokens, system=True): + """ + Initialize TextHistory. + + Args: + text (`str`): The text of the first segment. + tokens (`torch.LongTensor`): The tokens of the first segment. + system (`bool`, *optional*): Whether the first segment is a system or user segment. + """ + self.system_spans = [] + self.text_spans = [] + self.token_spans = [] + self.token_masks = torch.tensor([], dtype=torch.long).to(tokens.device) + self.text = "" + self.tokens = torch.tensor([], dtype=torch.long).to(tokens.device) + self.completed = False + self.truncated = False + self.reward = 0.0 + + self.prompt_color = "black on grey85" + self.system_color = "black on cyan3" + self.model_color = "black on deep_sky_blue1" + self.reward_color = "black on plum1" + + self.append_segment(text, tokens, system=system) + + def append_segment(self, text, tokens, system=True): + """ + Append a new segment to the history. + + Args: + text (`str`): The text of the new segment. + tokens (`torch.LongTensor`): The tokens of the new segment. + system (`bool`, *optional*): Whether the new segment is a system or user segment. + """ + + if len(text) == 0 or len(tokens) == 0: + raise ValueError("Can't append empty text or token list to history.") + + original_text_length = len(self.text) + + self.text += text + self.text_spans.append((original_text_length, len(self.text))) + self.system_spans.append(system) + + original_token_length = len(self.tokens) + + self.tokens = torch.cat((self.tokens, tokens)) + if system: + self.token_masks = torch.cat((self.token_masks, torch.zeros_like(tokens))) + else: + self.token_masks = torch.cat((self.token_masks, torch.ones_like(tokens))) + self.token_spans.append((original_token_length, len(self.tokens))) + + def complete(self, truncated=False): + """ + Mark the history as completed. + """ + self.completed = True + self.truncated = truncated + + @property + def last_text_segment(self): + """ + Get the last text segment. + """ + start, end = self.text_spans[-1] + return self.text[start:end] + + def split_query_response_tokens(self): + """ + Split the tokens into query and response tokens. + """ + split_index = self.token_spans[0][1] + query = self.tokens[:split_index] + response = self.tokens[split_index:] + mask = self.token_masks[split_index:] + + return query, response, mask + + def show_text(self, show_legend=False): + """ + Print the text history. + """ + if not is_rich_available(): + raise ImportError( + "The `rich` library is required to display text with formatting. Install it using `pip install rich`." + ) + + text = Text(self.text) + text.stylize(self.prompt_color, self.text_spans[0][0], self.text_spans[1][0]) + for i, (start, end) in enumerate(self.text_spans[1:]): + if self.system_spans[i + 1]: + text.stylize(self.system_color, start, end) + else: + text.stylize(self.model_color, start, end) + + text.append(f"\n\nReward: {self.reward}", style=self.reward_color) + print(text) + + if show_legend: + self.show_colour_legend() + + def show_tokens(self, tokenizer, show_legend=False): + """ + Print the history tokens. + """ + if not is_rich_available(): + raise ImportError( + "The `rich` library is required to display tokens with formatting. " + "Install it using `pip install rich`." + ) + + text = Text() + prompt_end = self.token_spans[0][1] + for i, (token, mask) in enumerate(zip(self.tokens, self.token_masks)): + if i < prompt_end: + text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.prompt_color) + text.append(" ") + elif mask == 0: + text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.system_color) + text.append(" ") + else: + text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.model_color) + text.append(" ") + text.append(f"\n\nReward: {self.reward}", style=self.reward_color) + print(text) + if show_legend: + self.show_colour_legend() + + def show_colour_legend(self): + """ + Print the colour legend. + """ + if not is_rich_available(): + raise ImportError( + "The `rich` library is required to display colour legends with formatting. " + "Install it using `pip install rich`." + ) + text = Text("\n\n(Colour Legend: ") + text.append("Prompt", style=self.prompt_color) + text.append("|") + text.append("System", style=self.system_color) + text.append("|") + text.append("Model", style=self.model_color) + text.append("|") + text.append("Reward", style=self.reward_color) + text.append(")") + print(text) + + +class TextEnvironment: + """ + The TextEnvironment enables interaction of a LLM with an environment using tools. + """ + + def __init__( + self, + model=None, + tokenizer=None, + tools=None, + reward_fn=None, + prompt=None, + max_turns=4, + max_tool_response=100, + max_length=None, + generation_kwargs=None, + ): + """ + Initialize TextEnvironment. + + Args: + model (`PreTrainedModelWrapper`): The model to use for generation. + tokenizer (`transformers.PreTrainedTokenizer`): The tokenizer to use for generation. + tools (list): A list of tools to use for interaction. + reward_fn (function): A function that takes a string and returns a reward. + prompt (str): The base prompt to use for generation. Is prepended to the tasks. + max_turns (Optional[int]): The maximum number of turns to allow. + max_tool_response (Optional[int]): The maximum number of characters to allow in a tool response. + max_length (Optional[int]): The maximum number of tokens to allow in an episode. + generation_kwargs (Optional[dict]): A dictionary of keyword arguments to pass to the model's generate method. + """ + warnings.warn( + "This class is deprecated and will be removed in version 0.21.0. To enable tool use with LLMs, check out smolagents (https://huggingface.co/docs/smolagents/index)", + DeprecationWarning, + ) + self.model = model + self.tokenizer = tokenizer + self.prompt = prompt + if isinstance(tools, dict): + self.tools = tools + else: + self.tools = {tool.__class__.__name__: tool for tool in tools} + self.reward_fn = reward_fn + self.max_length = max_length + self.request_token = "" + self.call_token = "" + self.response_token = "" + self.submit_token = "" + self.max_turns = max_turns + self.max_tool_response = max_tool_response + + if generation_kwargs is None: + self.generation_kwargs = dict() + else: + self.generation_kwargs = generation_kwargs + + self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder") + self.current_device = extract_model_from_parallel(self.model).pretrained_model.device + + def run(self, queries, **rewards_kwargs): + """ + Run the environment on a list of queries. + + Args: + queries (list[str]): A list of queries to run the model in the environment on. + """ + turns = 0 + + queries = [self.prompt + task for task in queries] + queries_tokens = [ + self.tokenizer(query, return_tensors="pt").input_ids[0].to(self.model.pretrained_model.device) + for query in queries + ] + + histories = [TextHistory(q, qt, system=True) for q, qt in zip(queries, queries_tokens)] + + while any(not history.completed for history in histories) and turns < self.max_turns: + histories = self.generate(histories) + histories = self.tasks_end_check(histories) + # TODO: make this parallel rather than for-loop + for i in range(len(histories)): + histories[i] = self.step(histories[i]) + histories = self.tasks_end_check(histories, model_turn=False) + turns += 1 + self.compute_reward(histories, **rewards_kwargs) + + # convert a list of (q, r, m) tuples to lists of all qs, rs, and ms respectively + queries, responses, masks = map(list, zip(*[history.split_query_response_tokens() for history in histories])) + + rewards = [history.reward for history in histories] + return queries, responses, masks, rewards, histories + + def step(self, history): + """ + Step the environment forward one turn. + + Args: + history (`TextHistory`): The history to step forward. + """ + truncated, ended = self.task_end_check(history) + if ended: + history.complete(truncated=truncated) + if history.completed: + return history + + tool, query = self.parse_tool_call(history.last_text_segment) + if tool is None or query is None: + response = f"Unknown tool call: {history.last_text_segment}" + else: + if tool not in self.tools: + response = f"Unknown tool {tool}." + try: + response = self.tools[tool](query) + except Exception as error: + response = f"Tool error: {str(error)}" + + if len(response) > self.max_tool_response: + response = response[: (self.max_tool_response - 3)] + "..." + + history.append_segment( + response + self.response_token, + self.tokenizer(response + self.response_token, return_tensors="pt") + .input_ids[0] + .to(self.model.pretrained_model.device), + system=True, + ) + + return history + + def parse_tool_call(self, text): + """ + Parse request string. Expected format: query + """ + result = re.search(f"(?<={self.request_token}).*?(?={self.call_token})", text, re.DOTALL) + + # if we can't find a / span we return none + if result is None: + return None, None + else: + extracted_text = result.group() + + result = re.search(r"<(.*?)>", extracted_text) + + # if we can't find a tool name we return none + if result is None: + return None, None + else: + tool = result.group(1) + + # split off the tool name + query = ">".join(extracted_text.split(">")[1:]) + + return tool, query + + def compute_reward(self, histories, **reward_kwargs): + """ + Compute the reward for a list of histories. + """ + rewards = self.reward_fn([history.last_text_segment for history in histories], **reward_kwargs) + for history, reward in zip(histories, rewards): + history.reward = reward + return histories + + def generate(self, histories): + """ + Generate responses for a list of histories. + """ + active_histories = [i for i, history in enumerate(histories) if not history.completed] + + query_tensors = [histories[i].tokens for i in active_histories] + response_tensors = self._generate_batched(query_tensors) + response_texts = self.tokenizer.batch_decode(response_tensors) + + for i, response_text, response_tensor in zip(active_histories, response_texts, response_tensors): + histories[i].append_segment(response_text, response_tensor, system=False) + + return histories + + def tasks_end_check(self, histories, model_turn=True): + """ + Check if the current generation sequences have finished. + """ + for history in histories: + if not history.completed: + truncated, ended = self.task_end_check(history, model_turn=model_turn) + if ended: + history.complete(truncated=truncated) + return histories + + def task_end_check(self, history, model_turn=True): + """ + Check if the current generation sequence has finished. + """ + truncated = False + ended = False + if history.completed: + return truncated, ended + if self.max_length is not None and len(self.tokenizer(history.text).input_ids[0]) > self.max_length: + truncated = True + ended = True + elif self.tokenizer.eos_token in history.text: + ended = True + elif model_turn and not ( + (self.request_token in history.last_text_segment and self.call_token in history.last_text_segment) + or self.submit_token in history.last_text_segment + ): + ended = True + elif self.submit_token in history.last_text_segment: + ended = True + return truncated, ended + + def _generate_batched( + self, + query_tensors, + batch_size: int = 16, + pad_to_multiple_of: Optional[int] = None, + ): + """ + Generate responses for a list of query tensors. + + Args: + query_tensors (list[torch.Tensor]): A list of query tensors to generate responses for. + batch_size (int): The batch size to use for generation. + pad_to_multiple_of (int): The padding length to use for generation. + """ + outputs = [] + padding_side_default = self.tokenizer.padding_side + if not self.is_encoder_decoder: + self.tokenizer.padding_side = "left" + + # in case we have fewer examples than bs + batch_size = min(len(query_tensors), batch_size) + + for i in range(0, len(query_tensors), batch_size): + # prevent overflow if query tensors are not even multiple of bs + end_index = min(len(query_tensors), i + batch_size) + + batch = query_tensors[i:end_index] + batch_mask = [torch.ones_like(element) for element in batch] + inputs = {"input_ids": batch, "attention_mask": batch_mask} + + padded_inputs = self.tokenizer.pad( + inputs, + padding=True, + max_length=None, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors="pt", + ).to(self.current_device) + + stopping_criteria = StringStoppingCriteria([self.call_token, self.submit_token], self.tokenizer) + + self.generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stopping_criteria]) + + generations = extract_model_from_parallel(self.model).generate(**padded_inputs, **self.generation_kwargs) + + for generation, mask, generated_tokens in zip( + generations, padded_inputs["attention_mask"], stopping_criteria.generated_tokens + ): + if not self.is_encoder_decoder: + output = generation[(1 - mask).sum() :] # remove padding + else: + output = generation + + if not self.is_encoder_decoder: + output = output[(mask).sum() :] # remove prompt + + # remove chunk generated after stopping criteria in batch mode + outputs.append(output[:generated_tokens]) + self.tokenizer.padding_side = padding_side_default + return outputs diff --git a/trl/extras/__init__.py b/trl/extras/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fddcdf1af1a6353717f446e555faeebfaf169b08 --- /dev/null +++ b/trl/extras/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ..import_utils import _LazyModule + + +_import_structure = { + "best_of_n_sampler": ["BestOfNSampler"], +} + +if TYPE_CHECKING: + from .best_of_n_sampler import BestOfNSampler +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/trl/extras/best_of_n_sampler.py b/trl/extras/best_of_n_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..d51323b5a96e164429f821a9d2228f8c515e5b71 --- /dev/null +++ b/trl/extras/best_of_n_sampler.py @@ -0,0 +1,130 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Optional, Union + +import torch +from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast, set_seed + +from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper + + +class BestOfNSampler: + def __init__( + self, + model: PreTrainedModelWrapper, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + queries_to_scores: Callable[[list[str]], list[float]], + length_sampler: Any, + sample_size: int = 4, + seed: Optional[int] = None, + n_candidates: int = 1, + generation_config: Optional[GenerationConfig] = None, + ) -> None: + r""" + Initialize the sampler for best-of-n generation + + Args: + model (`PreTrainedModelWrapper`): + The pretrained model to use for generation + tokenizer (`PreTrainedTokenizer` or `PreTrainedTokenizerFast`): + Tokenizer associated with the pretrained model + queries_to_scores (`Callable[[list[str]], list[float]]`): + Callable that takes a list of generated texts and returns the associated reward scores + length_sampler (`Any`): + Sampler used to sample the length of the generated text + sample_size (`int`): + Number of samples to generate for each query + seed (`int`, *optional*): + Random seed used to control generation + n_candidates (`int`): + Number of candidates to return for each query + generation_config (`GenerationConfig`, *optional*): + Generation config passed to the underlying model's `generate` method. + See `GenerationConfig` (https://huggingface.co/docs/transformers/v4.29.1/en/main_classes/text_generation#transformers.GenerationConfig) for more details + """ + if seed is not None: + set_seed(seed) + + if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): + raise ValueError( + f"tokenizer must be a PreTrainedTokenizer or PreTrainedTokenizerFast, got {type(tokenizer)}" + ) + if not isinstance(model, (SUPPORTED_ARCHITECTURES)): + raise ValueError( + f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}" + ) + + self.model = model + self.tokenizer = tokenizer + + self.queries_to_scores = queries_to_scores + self.length_sampler = length_sampler + self.gen_config = generation_config + self.sample_size = sample_size + self.n_candidates = n_candidates + + def generate( + self, + tokenized_query: Union[list[int], torch.Tensor, list[torch.Tensor], list[list[int]]], + skip_special_tokens: bool = True, + device: Optional[Union[str, torch.device]] = None, + **generation_kwargs, + ) -> list[list[str]]: + r""" + Generate the best of n samples for input queries + + Args: + tokenized_query (`list[int]` or `torch.Tensor` or `list[torch.Tensor]` or `list[int]`): + represents either a single tokenized query (a single tensor or a list of integers) or a batch of tokenized queries (a list of tensors or a list of lists of integers) + skip_special_tokens (`bool`): + Whether to remove the special tokens from the output + device (`str` or `torch.device`, *optional*): + The device on which the model will be loaded + **generation_kwargs (`dict`, *optional*): + Additional keyword arguments passed along to the underlying model's `generate` method. + This is used to override generation config + + Returns: + list[list[str]]: A list of lists of generated texts + """ + queries = None + + if isinstance(tokenized_query, torch.Tensor) and tokenized_query.ndim == 1: + queries = tokenized_query.unsqueeze(0) + elif isinstance(tokenized_query, list): + element_type = type(tokenized_query[0]) + if element_type is int: + queries = torch.tensor(tokenized_query).unsqueeze(0) + elif element_type is torch.Tensor: + queries = [tensor.reshape((1, -1)) for tensor in tokenized_query] + else: + queries = [torch.tensor(query).reshape((1, -1)) for query in tokenized_query] + + result = [] + + for query in queries: + queries = query.repeat((self.sample_size, 1)) + output = self.model.generate( + queries.to(device), + max_new_tokens=self.length_sampler(), + generation_config=self.gen_config, + **generation_kwargs, + ).squeeze() + output = self.tokenizer.batch_decode(output, skip_special_tokens=skip_special_tokens) + scores = torch.tensor(self.queries_to_scores(output)) + output = [output[i] for i in scores.topk(self.n_candidates).indices] + result.append(output) + + return result diff --git a/trl/extras/dataset_formatting.py b/trl/extras/dataset_formatting.py new file mode 100644 index 0000000000000000000000000000000000000000..f648976e3e475893add091eb4ce74abc40c91064 --- /dev/null +++ b/trl/extras/dataset_formatting.py @@ -0,0 +1,106 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Callable, Literal, Optional, Union + +from datasets import Dataset, Value +from transformers import AutoTokenizer + +from ..trainer.utils import ConstantLengthDataset + + +FORMAT_MAPPING = { + "chatml": [{"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}], + "instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)}, +} + + +def conversations_formatting_function( + tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"], tools: Optional[list] = None +): + r""" + return a callable function that takes in a "messages" dataset and returns a formatted dataset, based on the tokenizer + apply chat template to the dataset along with the schema of the list of functions in the tools list. + """ + + def format_dataset(examples): + if isinstance(examples[messages_field][0], list): + output_texts = [] + for i in range(len(examples[messages_field])): + output_texts.append( + tokenizer.apply_chat_template(examples[messages_field][i], tokenize=False, tools=tools) + ) + return output_texts + else: + return tokenizer.apply_chat_template(examples[messages_field], tokenize=False, tools=tools) + + return format_dataset + + +def instructions_formatting_function(tokenizer: AutoTokenizer): + r""" + return a callable function that takes in an "instructions" dataset and returns a formatted dataset, based on the tokenizer + apply chat template to the dataset + """ + + def format_dataset(examples): + if isinstance(examples["prompt"], list): + output_texts = [] + for i in range(len(examples["prompt"])): + converted_sample = [ + {"role": "user", "content": examples["prompt"][i]}, + {"role": "assistant", "content": examples["completion"][i]}, + ] + output_texts.append(tokenizer.apply_chat_template(converted_sample, tokenize=False)) + return output_texts + else: + converted_sample = [ + {"role": "user", "content": examples["prompt"]}, + {"role": "assistant", "content": examples["completion"]}, + ] + return tokenizer.apply_chat_template(converted_sample, tokenize=False) + + return format_dataset + + +def get_formatting_func_from_dataset( + dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer, tools: Optional[list] = None +) -> Optional[Callable]: + r""" + Finds the correct formatting function based on the dataset structure. Currently supported datasets are: + - `ChatML` with [{"role": str, "content": str}] + - `instruction` with [{"prompt": str, "completion": str}] + + Args: + dataset (Dataset): User dataset + tokenizer (AutoTokenizer): Tokenizer used for formatting + + Returns: + Callable: Formatting function if the dataset format is supported else None + """ + if isinstance(dataset, Dataset): + if "messages" in dataset.features: + if dataset.features["messages"] == FORMAT_MAPPING["chatml"]: + logging.info("Formatting dataset with chatml format") + return conversations_formatting_function(tokenizer, "messages", tools) + if "conversations" in dataset.features: + if dataset.features["conversations"] == FORMAT_MAPPING["chatml"]: + logging.info("Formatting dataset with chatml format") + return conversations_formatting_function(tokenizer, "conversations", tools) + elif dataset.features == FORMAT_MAPPING["instruction"]: + logging.info("Formatting dataset with instruction format") + return instructions_formatting_function(tokenizer) + + return None diff --git a/trl/extras/profiling.py b/trl/extras/profiling.py new file mode 100644 index 0000000000000000000000000000000000000000..2b763ea210fc31d5d03e1247b64eb7fba2e8fd04 --- /dev/null +++ b/trl/extras/profiling.py @@ -0,0 +1,98 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import functools +import time +from collections.abc import Generator + +from transformers import Trainer +from transformers.integrations import is_mlflow_available, is_wandb_available + + +if is_wandb_available(): + import wandb + +if is_mlflow_available(): + import mlflow + + +@contextlib.contextmanager +def profiling_context(trainer: Trainer, name: str) -> Generator[None, None, None]: + """ + A context manager function for profiling a block of code. Results are logged to Weights & Biases or MLflow + depending on the trainer's configuration. + + Args: + trainer (`~transformers.Trainer`): + Trainer object. + name (`str`): + Name of the block to be profiled. Used as a key in the logged dictionary. + + Example: + ```python + from transformers import Trainer + from trl.extras.profiling import profiling_context + + class MyTrainer(Trainer): + def some_method(self): + A = np.random.rand(1000, 1000) + B = np.random.rand(1000, 1000) + with profiling_context(self, "matrix_multiplication"): + # Code to profile: simulate a computationally expensive operation + result = A @ B # Matrix multiplication + ``` + """ + start_time = time.perf_counter() + yield + end_time = time.perf_counter() + duration = end_time - start_time + + profiling_metrics = {f"profiling/Time taken: {trainer.__class__.__name__}.{name}": duration} + if "wandb" in trainer.args.report_to and wandb.run is not None and trainer.accelerator.is_main_process: + wandb.log(profiling_metrics) + + if "mlflow" in trainer.args.report_to and mlflow.run is not None and trainer.accelerator.is_main_process: + mlflow.log_metrics(profiling_metrics, step=trainer.state.global_step) + + +def profiling_decorator(func: callable) -> callable: + """ + Decorator to profile a function and log execution time using [`extras.profiling.profiling_context`]. + + Args: + func (`callable`): + Function to be profiled. + + Example: + ```python + from transformers import Trainer + from trl.extras.profiling import profiling_decorator + + class MyTrainer(Trainer): + @profiling_decorator + def some_method(self): + A = np.random.rand(1000, 1000) + B = np.random.rand(1000, 1000) + # Code to profile: simulate a computationally expensive operation + result = A @ B + ``` + """ + + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + with profiling_context(self, func.__name__): + return func(self, *args, **kwargs) + + return wrapper diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py new file mode 100644 index 0000000000000000000000000000000000000000..4724d96e759e35f9682aec764ef5b5ae5f602cdb --- /dev/null +++ b/trl/extras/vllm_client.py @@ -0,0 +1,329 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import atexit +import logging +import socket +import time +from typing import Optional +from urllib.parse import urlparse + +import torch +from torch import nn + +from ..import_utils import is_requests_available, is_vllm_ascend_available, is_vllm_available + + +if is_requests_available(): + import requests + from requests import ConnectionError + + +if is_vllm_available(): + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + from vllm.distributed.utils import StatelessProcessGroup + + if is_vllm_ascend_available(): + from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator + + +logger = logging.getLogger(__name__) + + +class VLLMClient: + """ + A client class to interact with a vLLM server. + + This class provides methods to generate completions, initialize and manage weight update groups, and update model + weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`. + + Args: + base_url (`str` or `None`, *optional*, defaults to `None`): + Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `host` and `server_port` are + ignored. + host (`str`, *optional*, defaults to `"0.0.0.0"`): + IP address of the vLLM server. Ignored if `base_url` is provided. + server_port (`int`, *optional*, defaults to `8000`): + Port number of the vLLM server. Ignored if `base_url` is provided. + group_port (`int`, *optional*, defaults to `51216`): + Port number for the weight update group. + connection_timeout (`float`, *optional*, defaults to `0.0`): + Total timeout duration in seconds to wait for the server to be up. If the server is not up after the + timeout, a `ConnectionError` is raised. + + Examples: + Run the vLLM server with the model `Qwen/Qwen2.5-7B`: + + ``` + $ trl vllm-serve --model Qwen/Qwen2.5-7B + ... + INFO: Application startup complete. + INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) + ``` + + Use the client to generate completions and update model weights: + + ```python + >>> from trl.extras.vllm_client import VLLMClient + >>> client = VLLMClient() + >>> client.generate(["Hello, AI!", "Tell me a joke"]) + [[2980, 498, 1492, 752, 448, 264, 13027, 8645, 30, 358, 2776, 4460, 311, 3270, 264, 2025], + [911, 7988, 1251, 382, 3838, 653, 498, 1618, 4325, 879, 2581, 20027, 264, 21428, 30, 362]] + + >>> from transformers import AutoModelForCausalLM + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B", device_map="cuda") + >>> client.init_communicator() + >>> client.update_model_params(model) + ``` + + There are several ways to initialize the client: + + ```python + VLLMClient(base_url="http://localhost:8000") + VLLMClient(base_url="http://192.168.1.100:8000") + VLLMClient(host="localhost", server_port=8000) + VLLMClient(host="192.168.1.100", server_port=8000) + ``` + """ + + def __init__( + self, + base_url: Optional[str] = None, + host: str = "0.0.0.0", + server_port: int = 8000, + group_port: int = 51216, + connection_timeout: float = 0.0, + ): + if not is_requests_available(): + raise ImportError("requests is not installed. Please install it with `pip install requests`.") + if not is_vllm_available(): + raise ImportError("vLLM is not installed. Please install it with `pip install vllm`.") + + self.session = requests.Session() + + if base_url is not None: + # Parse the base_url to extract host and port + parsed_url = urlparse(base_url) + self.host = socket.gethostbyname(parsed_url.hostname) + scheme = parsed_url.scheme or "http" + self.base_url = f"{scheme}://{parsed_url.netloc}{parsed_url.path}" + else: + self.host = host + self.server_port = server_port + self.base_url = f"http://{self.host}:{self.server_port}" + self.group_port = group_port + self.check_server(connection_timeout) # check server and fail after timeout + + def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0): + """ + Check server availability with retries on failure, within a total timeout duration. If the server is not up + after the total timeout duration, raise a `ConnectionError`. + + Args: + retry_interval (`float`, *optional*, defaults to `2.0`): + Interval in seconds between retries. + total_timeout (`float`, *optional*, defaults to `0.0`): + Total timeout duration in seconds. + """ + url = f"{self.base_url}/health/" + start_time = time.time() # Record the start time + + while True: + try: + response = requests.get(url) + except requests.exceptions.RequestException as exc: + # Check if the total timeout duration has passed + elapsed_time = time.time() - start_time + if elapsed_time >= total_timeout: + raise ConnectionError( + f"The vLLM server can't be reached at {self.base_url} after {total_timeout} seconds. Make " + "sure the server is running by running `trl vllm-serve`." + ) from exc + else: + if response.status_code == 200: + if "X-Forwarded-For" in response.headers: + self.host = response.headers["X-Forwarded-For"] + logger.info("Server is up!") + return None + + # Retry logic: wait before trying again + logger.info(f"Server is not up yet. Retrying in {retry_interval} seconds...") + time.sleep(retry_interval) + + def generate( + self, + prompts: list[str], + n: int = 1, + repetition_penalty: float = 1.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + min_p: float = 0.0, + max_tokens: int = 16, + guided_decoding_regex: Optional[str] = None, + ) -> list[list[int]]: + """ + Generates model completions for the provided prompts. + + Args: + prompts (`list[str]`): + List of text prompts for which the model will generate completions. + n (`int`, *optional*, defaults to `1`): + Number of completions to generate for each prompt. + repetition_penalty (`float`, *optional*, defaults to `1.0`): + Parameter for repetition penalty. 1.0 means no penalty. + temperature (`float`, *optional*, defaults to `1.0`): + Temperature parameter for sampling. Higher values increase diversity. + top_p (`float`, *optional*, defaults to `1.0`): + Top-p sampling parameter.`1.0` means no truncation. + top_k (`int`, *optional*, defaults to `-1`): + Top-k sampling parameter. `-1` means no truncation. + min_p (`float`, *optional*, defaults to `0.0`): + Minimum probability for sampling. + max_tokens (`int`, *optional*, defaults to `16`): + Maximum number of tokens to generate for each prompt. + guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`): + Regular expression to guide the decoding process. + + Returns: + `list[list[int]]`: + List of lists of token IDs representing the model-generated completions for each prompt. + """ + url = f"{self.base_url}/generate/" + response = self.session.post( + url, + json={ + "prompts": prompts, + "n": n, + "repetition_penalty": repetition_penalty, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "min_p": min_p, + "max_tokens": max_tokens, + "guided_decoding_regex": guided_decoding_regex, + }, + ) + if response.status_code == 200: + return response.json()["completion_ids"] + else: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + + def init_communicator(self): + """ + Initializes the weight update group in a distributed setup for model synchronization. + """ + # Get the world size from the server + url = f"{self.base_url}/get_world_size/" + response = requests.get(url) + if response.status_code == 200: + vllm_world_size = response.json()["world_size"] + else: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + + world_size = vllm_world_size + 1 # add the client to the world + self.rank = vllm_world_size # the client's rank is the last process + + # Initialize weight update group + url = f"{self.base_url}/init_communicator/" + # In the server side, the host is set to 0.0.0.0 + response = self.session.post(url, json={"host": "0.0.0.0", "port": self.group_port, "world_size": world_size}) + if response.status_code != 200: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + + # Brief delay to allow server initialization. While not strictly required (client socket will retry on + # connection failure), this prevents log warnings like: + # [W416 23:24:57.460001114 socket.cpp:204] [c10d] The hostname of the client socket cannot be retrieved. err=-3 + time.sleep(0.1) + + # Set up the communication group for weight broadcasting + pg = StatelessProcessGroup.create(host=self.host, port=self.group_port, rank=self.rank, world_size=world_size) + self.pynccl_comm = PyNcclCommunicator(pg, device=0) + + # When the client object is deleted, close the weight update group + atexit.register(self.close_communicator) + + def update_named_param(self, name: str, weights: torch.Tensor): + """ + Updates a specific named parameter in the model and broadcasts it to other processes. + + Args: + name (`str`): + Name of the layer whose weights are being updated. + weights (`torch.Tensor`): + Tensor containing the updated weights. + """ + dtype, shape = str(weights.dtype), tuple(weights.shape) + url = f"{self.base_url}/update_named_param/" + response = self.session.post(url, json={"name": name, "dtype": dtype, "shape": shape}) + if response.status_code != 200: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + + # Broadcast the weights to the other processes + self.pynccl_comm.broadcast(weights, src=self.rank) + self.pynccl_comm.group.barrier() + + def update_model_params(self, model: nn.Module): + """ + Updates all parameters of the given model by calling `update_named_param` for each parameter in the model. + + Args: + model (`nn.Module`): + Model whose parameters (weights/biases) are to be updated. + """ + for name, param in model.named_parameters(): + # Update each parameter individually + self.update_named_param(name, param.data) + + def reset_prefix_cache(self): + """ + Resets the prefix cache for the model. + """ + url = f"{self.base_url}/reset_prefix_cache/" + response = self.session.post(url) + if response.status_code != 200: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + + def close_communicator(self): + """ + Closes the weight update group and cleans up the communication group. + """ + url = f"{self.base_url}/close_communicator/" + + try: + response = self.session.post(url) + except ConnectionError: + # The server might be already down, so we don't need to close the communicator + pass + else: + if response.status_code != 200: + raise Exception(f"Request failed: {response.status_code}, {response.text}") + + +# Example usage +if __name__ == "__main__": + from vllm import SamplingParams + + client = VLLMClient() + client.init_communicator() + + # Generate completions + responses = client.generate(["Hello, AI!", "Tell me a joke"], n=4, max_tokens=32, sampling_params=SamplingParams()) + print("Responses:", responses) # noqa + + # Update model weights + from transformers import AutoModelForCausalLM + + model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B").to("cuda") + client.update_model_params(model) diff --git a/trl/import_utils.py b/trl/import_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dd594e36ef6b4da68e3ee8f7ab59125e8ea53645 --- /dev/null +++ b/trl/import_utils.py @@ -0,0 +1,156 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import os +from itertools import chain +from types import ModuleType +from typing import Any + +from packaging import version +from transformers.utils.import_utils import _is_package_available + + +LIGER_KERNEL_MIN_VERSION = "0.5.8" + +# Use same as transformers.utils.import_utils +_deepspeed_available = _is_package_available("deepspeed") +_diffusers_available = _is_package_available("diffusers") +_fastapi_available = _is_package_available("fastapi") +_is_liger_kernel_available, _liger_kernel_version = _is_package_available("liger_kernel", return_version=True) +_llm_blender_available = _is_package_available("llm_blender") +_mergekit_available = _is_package_available("mergekit") +_pydantic_available = _is_package_available("pydantic") +_requests_available = _is_package_available("requests") +_unsloth_available = _is_package_available("unsloth") +_uvicorn_available = _is_package_available("uvicorn") +_vllm_available = _is_package_available("vllm") +_vllm_ascend_available = _is_package_available("vllm_ascend") +_joblib_available = _is_package_available("joblib") + + +def is_deepspeed_available() -> bool: + return _deepspeed_available + + +def is_diffusers_available() -> bool: + return _diffusers_available + + +def is_fastapi_available() -> bool: + return _fastapi_available + + +def is_liger_kernel_available(min_version: str = LIGER_KERNEL_MIN_VERSION) -> bool: + return _is_liger_kernel_available and version.parse(_liger_kernel_version) >= version.parse(min_version) + + +def is_llm_blender_available() -> bool: + return _llm_blender_available + + +def is_mergekit_available() -> bool: + return _mergekit_available + + +def is_pydantic_available() -> bool: + return _pydantic_available + + +def is_requests_available() -> bool: + return _requests_available + + +def is_unsloth_available() -> bool: + return _unsloth_available + + +def is_uvicorn_available() -> bool: + return _uvicorn_available + + +def is_vllm_available() -> bool: + return _vllm_available + + +def is_vllm_ascend_available() -> bool: + return _vllm_ascend_available + + +def is_joblib_available() -> bool: + return _joblib_available + + +class _LazyModule(ModuleType): + """ + Module class that surfaces all objects but only performs associated imports when the objects are requested. + """ + + # Very heavily inspired by optuna.integration._IntegrationModule + # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py + def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None): + super().__init__(name) + self._modules = set(import_structure.keys()) + self._class_to_module = {} + for key, values in import_structure.items(): + for value in values: + self._class_to_module[value] = key + # Needed for autocompletion in an IDE + self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values())) + self.__file__ = module_file + self.__spec__ = module_spec + self.__path__ = [os.path.dirname(module_file)] + self._objects = {} if extra_objects is None else extra_objects + self._name = name + self._import_structure = import_structure + + # Needed for autocompletion in an IDE + def __dir__(self): + result = super().__dir__() + # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether + # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir. + for attr in self.__all__: + if attr not in result: + result.append(attr) + return result + + def __getattr__(self, name: str) -> Any: + if name in self._objects: + return self._objects[name] + if name in self._modules: + value = self._get_module(name) + elif name in self._class_to_module.keys(): + module = self._get_module(self._class_to_module[name]) + value = getattr(module, name) + else: + raise AttributeError(f"module {self.__name__} has no attribute {name}") + + setattr(self, name, value) + return value + + def _get_module(self, module_name: str): + try: + return importlib.import_module("." + module_name, self.__name__) + except Exception as e: + raise RuntimeError( + f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its" + f" traceback):\n{e}" + ) from e + + def __reduce__(self): + return (self.__class__, (self._name, self.__file__, self._import_structure)) + + +class OptionalDependencyNotAvailable(BaseException): + """Internally used error class for signalling an optional dependency was not found.""" diff --git a/trl/mergekit_utils.py b/trl/mergekit_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d5eb52484e9bc1d4b37517405df2fad37dff9d87 --- /dev/null +++ b/trl/mergekit_utils.py @@ -0,0 +1,281 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from huggingface_hub import HfApi + +from trl.import_utils import is_mergekit_available + + +if is_mergekit_available(): + from mergekit.config import MergeConfiguration + from mergekit.merge import MergeOptions, run_merge + + +def upload_model_to_hf(folder_path: str, repo_id: str): + api = HfApi() + # Create the repository if it doesn't exist + repo = api.create_repo(repo_id, repo_type="model") + + # Upload the folder to the specified repository + api.upload_folder( + folder_path=folder_path, + repo_id=repo.repo_id, + repo_type=repo.repo_type, + ) + + +class MergeConfig: + r""" + Configuration class for merging two models using `mergekit`. + + This class provides a structured way to configure and generate merge configurations for various merge methods, + such as `linear`, `ties`, `dare_ties`, and `slerp`. + + Args: + method (`str`, *optional*, defaults to `"linear"`): + Merge method to use. Supported methods include: + + - `"linear"`: Linearly combines two models with specified weights. + - `"ties"`: Combines two models using the TIES method with density parameters. + - `"dare_ties"`: A variant of TIES for domain adaptation. + - `"slerp"`: Combines models using spherical linear interpolation. + + Note: + + For more details about the merge methods and how they are implemented, see the + [MergeKit GitHub repository](https://github.com/arcee-ai/mergekit?tab=readme-ov-file#merge-methods). + + Attributes: + method (`str`): The merge method to use. + policy_model_path (`str` or `None`): Path to the policy model. + target_model_path (`str` or `None`): Path to the target model. + policy_model_weight (`float`): Weight for the policy model (for `linear` and `ties` methods). + target_model_weight (`float`): Weight for the target model (for `linear` and `ties` methods). + policy_model_density (`list[float]`): Density parameters for the policy model (for `ties` and `dare_ties`). + target_model_density (`list[float]`): Density parameters for the target model (for `ties` and `dare_ties`). + normalize (`float` or `None`): Normalization factor for the TIES method. + t_values (`float` or `None`): Interpolation factor for the SLERP method. + dtype (`str`): Data type to use for merging, e.g., `"float16"`. + """ + + def __init__(self, method: str = "linear"): + if not is_mergekit_available(): + raise ImportError("MergeConfig requires the `mergekit` extra. To install, run `pip install mergekit`.") + self.method = method + self.policy_model_path = None + self.target_model_path = None + + # Initialize relevant parameters based on the method + if method == "linear": + self.policy_model_weight = 0.5 + self.target_model_weight = 0.5 + self.dtype = "float16" + elif method == "ties": + self.policy_model_weight = 1.0 + self.policy_model_density = [1.0, 0.7, 0.1] + self.target_model_weight = 1.0 + self.target_model_density = [1.0] + self.normalize = 1.0 + self.dtype = "float16" + elif method == "dare_ties": + self.policy_model_weight = 1.0 + self.policy_model_density = [1.0, 0.7, 0.1] + self.target_model_weight = 1.0 + self.target_model_density = [1.0] + self.normalize = 1.0 + self.dtype = "float16" + elif method == "slerp": + self.t_values = 0.5 + self.dtype = "float16" + else: + raise ValueError(f"Unsupported merge method: {method}") + + def create_merge_config_linear(self) -> "MergeConfiguration": + """ + Creates a merge configuration for a linear merge of two models with specified weights. + """ + # Create the merge configuration dictionary + merge_config_dict = { + "dtype": self.dtype, + "merge_method": "linear", + "models": [ + {"model": self.policy_model_path, "parameters": {"weight": self.policy_model_weight}}, + {"model": self.target_model_path, "parameters": {"weight": self.target_model_weight}}, + ], + } + + # Create the MergeConfiguration from the dictionary + merge_config = MergeConfiguration.model_validate(merge_config_dict) + + return merge_config + + def create_merge_config_ties(self) -> "MergeConfiguration": + """ + Creates a merge configuration for a TIES merge of two models, with specified weights and densities. + """ + # Create the TIES merge configuration dictionary + merge_config_dict = { + "merge_method": "ties", + "slices": None, # Optional slices if needed + "models": [ + { + "model": { + "model": {"path": self.target_model_path, "revision": None}, + "lora": None, + "override_architecture": None, + }, + "parameters": {"density": self.target_model_density, "weight": self.target_model_weight}, + }, + { + "model": { + "model": {"path": self.policy_model_path, "revision": None}, + "lora": None, + "override_architecture": None, + }, + "parameters": {"density": self.policy_model_density, "weight": self.policy_model_weight}, + }, + ], + "parameters": {"normalize": self.normalize}, + "base_model": { + "model": {"path": self.policy_model_path, "revision": None}, + "lora": None, + "override_architecture": None, + }, + "dtype": self.dtype, + "tokenizer_source": None, + "tokenizer": None, + "chat_template": None, + "out_dtype": None, + } + + # Create the MergeConfiguration from the dictionary + merge_config = MergeConfiguration.model_validate(merge_config_dict) + + return merge_config + + def create_merge_config_dare_ties(self) -> "MergeConfiguration": + """ + Creates a merge configuration for a DARE TIES merge of two models, with specified weights and densities. + """ + # Create the DARE TIES merge configuration dictionary + merge_config_dict = { + "merge_method": "dare_ties", + "slices": None, # Optional slices if needed + "models": [ + { + "model": { + "model": {"path": self.target_model_path, "revision": None}, + "lora": None, + "override_architecture": None, + }, + "parameters": {"density": self.target_model_density, "weight": self.target_model_weight}, + }, + { + "model": { + "model": {"path": self.policy_model_path, "revision": None}, + "lora": None, + "override_architecture": None, + }, + "parameters": {"density": self.policy_model_density, "weight": self.policy_model_weight}, + }, + ], + "parameters": {"normalize": self.normalize}, + "base_model": { + "model": {"path": self.policy_model_path, "revision": None}, + "lora": None, + "override_architecture": None, + }, + "dtype": self.dtype, + "tokenizer_source": None, + "tokenizer": None, + "chat_template": None, + "out_dtype": None, + } + + # Create the MergeConfiguration from the dictionary + merge_config = MergeConfiguration.model_validate(merge_config_dict) + + return merge_config + + def create_merge_config_slerp(self) -> "MergeConfiguration": + """ + Creates a merge configuration for a SLERP merge of a model with a base model. + """ + + # Create the SLERP merge configuration dictionary + merge_config_dict = { + "merge_method": "slerp", + "slices": None, # Optional slices if needed + "models": [ + { + "model": { + "model": {"path": self.target_model_path, "revision": None}, + "lora": None, + "override_architecture": None, + }, + "parameters": None, # No specific parameters for SLERP model + } + ], + "parameters": { + "t": self.t_values # Set the t values for SLERP + }, + "base_model": { + "model": {"path": self.policy_model_path, "revision": None}, + "lora": None, + "override_architecture": None, + }, + "dtype": self.dtype, + "tokenizer_source": None, + "tokenizer": None, + "chat_template": None, + "out_dtype": None, + } + + # Create the MergeConfiguration from the dictionary + merge_config = MergeConfiguration.model_validate(merge_config_dict) + + return merge_config + + def create(self) -> "MergeConfiguration": + if self.method == "linear": + return self.create_merge_config_linear() + elif self.method == "ties": + return self.create_merge_config_ties() + elif self.method == "dare_ties": + return self.create_merge_config_dare_ties() + elif self.method == "slerp": + return self.create_merge_config_slerp() + + +def merge_models(config: MergeConfig, out_path: str): + """ + Merge two models using mergekit + + Args: + config (`MergeConfig`): The merge configuration. + out_path (`str`): The output path for the merged model. + """ + if not is_mergekit_available(): + raise ImportError("merge_models requires the `mergekit` extra. To install, run `pip install mergekit`.") + run_merge( + config, + out_path=out_path, + options=MergeOptions( + cuda=torch.cuda.is_available(), + copy_tokenizer=True, + lazy_unpickle=False, + low_cpu_memory=False, + ), + ) diff --git a/trl/models/__init__.py b/trl/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2abcb7093e9393b35d42a7f11245a2c704dee3c6 --- /dev/null +++ b/trl/models/__init__.py @@ -0,0 +1,73 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ..import_utils import OptionalDependencyNotAvailable, _LazyModule, is_diffusers_available + + +_import_structure = { + "activation_offloading": ["get_act_offloading_ctx_manager"], + "modeling_base": ["GeometricMixtureWrapper", "PreTrainedModelWrapper", "create_reference_model"], + "modeling_value_head": ["AutoModelForCausalLMWithValueHead", "AutoModelForSeq2SeqLMWithValueHead"], + "utils": [ + "SUPPORTED_ARCHITECTURES", + "prepare_deepspeed", + "prepare_fsdp", + "setup_chat_format", + "unwrap_model_for_generation", + ], +} + +try: + if not is_diffusers_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_sd_base"] = [ + "DDPOPipelineOutput", + "DDPOSchedulerOutput", + "DDPOStableDiffusionPipeline", + "DefaultDDPOStableDiffusionPipeline", + ] + +if TYPE_CHECKING: + from .activation_offloading import get_act_offloading_ctx_manager + from .modeling_base import GeometricMixtureWrapper, PreTrainedModelWrapper, create_reference_model + from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead + from .utils import ( + SUPPORTED_ARCHITECTURES, + prepare_deepspeed, + prepare_fsdp, + setup_chat_format, + unwrap_model_for_generation, + ) + + try: + if not is_diffusers_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_sd_base import ( + DDPOPipelineOutput, + DDPOSchedulerOutput, + DDPOStableDiffusionPipeline, + DefaultDDPOStableDiffusionPipeline, + ) +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/trl/models/activation_offloading.py b/trl/models/activation_offloading.py new file mode 100644 index 0000000000000000000000000000000000000000..ed7b23b0fdffc4335a1d68f5722ef571284aed2b --- /dev/null +++ b/trl/models/activation_offloading.py @@ -0,0 +1,462 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of https://github.com/pytorch/torchtune. + +import warnings + +import psutil +import torch +from torch import nn +from torch.autograd.graph import saved_tensors_hooks + + +class OffloadActivations(saved_tensors_hooks): + """ + Context manager under which activation tensors created in the forward pass will be offloaded. + + Enable the memory efficiency technique of activation offloading, where activations bigger than `min_offload_size` + bytes will be offloaded to CPU in the forward and brought back in the backward. This is in contrast to maintaining + the activation on GPU VRAM throughout the program. + + This manager contains the option of using one additional CUDA stream to handle the communication between CUDA and + CPU, which is intended to overlap with the default computation stream to improve runtime. We designed + synchronization with a few heuristics for optimizing the tradeoff between runtime vs memory usage. + + Args: + use_pin_memory (`bool`, *optional*, defaults to `True`): + Whether to offloaded Tensor will be placed in pinned memory on the CPU. Pinned memory allows the Tensor to + be moved back onto GPU more quickly but is a limited resource. + use_streams (`bool`, *optional*, defaults to `True`): + Whether to use streams for performance optimization where the communications get overlapped with the + computation. Requires a torch build after torch-2.5.0. + min_offload_size (`int`, *optional*, defaults to `1024`): + Minimum number of bytes a Tensor must be in order to qualify for offloading. If the tensor is too small, we + do not want to waste bandwidth and resources moving it to CPU and back. + max_fwd_stash_size (`int`, *optional*, defaults to `5`): + Maximum size of the forward stash, or the maximum number of consecutive activations to keep alive during + the forward pass. This number must be at least 1. Keeping alive more activations will potentially allow + more overlap between the communication and compute streams at the cost of increasing memory usage. Keeping + alive fewer activations will conserve memory, but may cause poor overlap between the streams, increasing + runtime. + + Raises: + ValueError: if `max_fwd_stash_size` is not at least `1`. + + Example: + >>> with OffloadActivations(): + >>> outputs = model(inputs, labels=labels) + >>> loss = outputs.loss + >>> loss.backward() + """ + + def __init__( + self, + use_pin_memory: bool = True, + use_streams: bool = True, + min_offload_size: int = 1024, + max_fwd_stash_size: int = 5, + ) -> None: + self.use_streams = use_streams + + self.min_tensor_size_bytes = min_offload_size # we don't want to bother with small tensors + self.tracker = {} # tensor_id => (new_tensor, if_modified) ---> track what saved/offloaded tensors are where + self.tensor_id = 0 + self.is_first_forward_call = True + self.is_first_backward_call = True + self.is_first_forward_pass = True + + # Managing cpu memory + self.use_pin_memory = use_pin_memory + self.virtual_memory_safe_pct = 60 # we should not exceed this percentage of memory + + self.accelerator_type = ( + torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda" + ) + # NOTE: xpu doesn't have `default_stream` API, use `current_stream` instead + self.s0 = ( + torch.xpu.current_stream() if self.accelerator_type == "xpu" else torch.cuda.default_stream() + ) # comp stream + + # For streaming + if self.use_streams: + self.s1 = torch.Stream() if self.accelerator_type == "xpu" else torch.cuda.Stream() # comms stream + self.fwd_stash = {} # tensor_id => (activation, ev1) + if max_fwd_stash_size < 1: + raise ValueError(f"max_fwd_stash_size should be at least 1 but is {max_fwd_stash_size}") + self.max_fwd_stash_size = max_fwd_stash_size + self.bwd_tensor_stash = {} # tensor_id => activation + self.bwd_ev_stash = {} # tensor_id => ev0 + self.curr_graph_id = None + self.curr_autograd_node = None + + # -------- platform util functions -------- # + def verify_sufficient_virtual_memory(): + curr_pct = get_cpu_ram_pct() + if curr_pct > self.virtual_memory_safe_pct: + warnings.warn(f"{curr_pct=}% > {self.virtual_memory_safe_pct=}% of virtual memory used") + + def get_cpu_ram_pct() -> float: + # get the percentage of memory used by the system + return psutil.virtual_memory().percent + + def get_tensor_id() -> int: + # create a unique id for each tensor we are managing + self.tensor_id += 1 + return self.tensor_id + + def get_num_bytes_tensor(x: torch.Tensor) -> int: + # get the number of bytes in a tensor, for memory management purposes + return x.element_size() * x.nelement() # x.element_size() * x._base_storage().nbytes() + + # -------- core pack / unpack work -------- # + def pack_tensor(activation: torch.Tensor) -> int: + # activations are passed in during forward pass - from here we take over and return a unique id + if self.is_first_forward_call: + if len(self.tracker) != 0: + raise ValueError("Backward pass should have cleared tracker of all tensors") + + # set training phase trackers + self.is_first_forward_call = False + self.is_first_backward_call = True + + # query for basic tensor info + num_bytes = get_num_bytes_tensor(activation) + tensor_id = get_tensor_id() + + # only offload hefty bois if they're activations on CUDA (our heuristic + # for that is to check if they're not params or buffers)! + if ( + activation.device.type in ["cuda", "xpu"] + and num_bytes >= self.min_tensor_size_bytes + and ( + not isinstance(activation, torch.nn.Parameter) + and not (hasattr(torch.nn, "Buffer") and isinstance(activation, torch.nn.Buffer)) + ) + ): + if self.use_streams: + # First, sync back and dereference previously offloaded tensors + # as the offloading should be done sufficiently long ago. + for id in list(self.fwd_stash.keys()): + if id <= tensor_id - self.max_fwd_stash_size: + _, ev = self.fwd_stash[id] + self.s0.wait_event(ev) + del self.fwd_stash[id] + else: + break + + # Sync in, offload, and add an event to sync back later + self.s1.wait_stream(self.s0) + + stream = self.s1 if self.use_streams else self.s0 + with stream if self.accelerator_type == "xpu" else torch.cuda.stream(stream): + cpu_tensor = torch.empty_like(activation, pin_memory=self.use_pin_memory, device="cpu") + cpu_tensor.copy_(activation, non_blocking=True) + self.tracker[tensor_id] = ( + cpu_tensor, + True, # True = (in future) modified + ) + + if self.use_streams: + event = self.s1.record_event() + + # Stash to keep activation alive til s1 is done + self.fwd_stash[tensor_id] = (activation, event) + else: + self.tracker[tensor_id] = ( + activation, + False, + ) # False = not modified, tensor is as is + + return tensor_id + + def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor: + # backward pass - we are called with the tensor_id, which + # we will use to retrieve the saved/offloaded tensor + if self.is_first_backward_call: + if self.is_first_forward_pass: + self.is_first_forward_pass = False + if self.use_pin_memory: + verify_sufficient_virtual_memory() + + self.is_first_backward_call = False + self.is_first_forward_call = True + + if unpack_tensor_id not in self.tracker: + raise ValueError(f"Untracked tensor with id {unpack_tensor_id}") + + maybe_accelerator_tensor, modified = self.tracker[unpack_tensor_id] + if modified: + accelerator_tensor = maybe_accelerator_tensor.to(self.accelerator_type, non_blocking=True) + maybe_accelerator_tensor = accelerator_tensor + + # clear tensor from tracking + del self.tracker[unpack_tensor_id] + return maybe_accelerator_tensor + + def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor: + # backward pass - we are called with the tensor_id, which + # we will use to retrieve the saved/offloaded tensor + if self.is_first_backward_call: + self.curr_graph_id = torch._C._current_graph_task_id() + + def wait_and_del_remaining_references() -> None: + for id in list(self.bwd_tensor_stash.keys()): + event = self.bwd_ev_stash[id] + self.s1.wait_event(event) + del self.bwd_tensor_stash[id] + + # Register a callback to the end of autograd to clean everything up + torch.autograd.variable.Variable._execution_engine.queue_callback(wait_and_del_remaining_references) + + if self.is_first_forward_pass: + self.is_first_forward_pass = False + if self.use_pin_memory: + verify_sufficient_virtual_memory() + + self.is_first_backward_call = False + self.is_first_forward_call = True + + if unpack_tensor_id not in self.tracker: + raise ValueError(f"untracked tensor with id {unpack_tensor_id}") + + maybe_accelerator_tensor, modified = self.tracker[unpack_tensor_id] + if modified: + # Get data on the current autograd node + graph_id = torch._C._current_graph_task_id() + node = torch._C._current_autograd_node() + prev_node_ids = [] + + # If we're on a new node, mark prev node's tensors to be freed later + if graph_id == self.curr_graph_id and self.curr_autograd_node != node: + self.curr_autograd_node = node + prev_node_ids = list(self.bwd_tensor_stash.keys()) + + brought_back_from_cpu = True + if unpack_tensor_id in self.fwd_stash: + maybe_accelerator_tensor = self.fwd_stash[unpack_tensor_id][0] + brought_back_from_cpu = False + else: + # Kick off the process to bring tensors back + with self.s1 if self.accelerator_type == "xpu" else torch.cuda.stream(self.s1): + accelerator_tensor = maybe_accelerator_tensor.to(self.accelerator_type, non_blocking=True) + maybe_accelerator_tensor = accelerator_tensor + + # Tell comp stream to wait for the info to be loaded before executing + self.s0.wait_stream(self.s1) + + # Stash the tensor to keep memory alive until compute stream is complete + self.bwd_tensor_stash[unpack_tensor_id] = maybe_accelerator_tensor + + # Note: [Track views of the unpacked] + # Why do we get the use count of the unpacked tensor here? We want an + # initial count to compare to later, during the post-hook of the + # backward node, when we need to decide whether we're allowed to free + # the tensor yet. In what obscure cases must we delay freeing the + # tensor (and thus call record_stream)? + # 1. Any of the outputs of the backward node is a view of the unpacked + # tensor. + # 2. In the case that this unpacked tensor will be used in a + # checkpointed region, if one of the recomputed saved tensors ends + # up as a view of the unpacked tensor. + # 3. The user abuses the system somehow and manually relies on the + # unpacked tensor to exist after the backward node has executed. + storage_refcount = torch._C._storage_Use_Count(maybe_accelerator_tensor.untyped_storage()._cdata) + + def hook(outputs, inputs): + # create events for the current node inputs/outputs if they were streamed in + if brought_back_from_cpu: + # See Note: [Track views of the unpacked] + # IF any of the outputs is a view of the tensor, OR if a view of + # the tensor has been saved as a part of checkpoint's recompute + # process, OR the user has abusedly incurred a reference on the + # unpacked tensor, THEN the tensor might be used later and we + # cannot presume to delete it after only the current node is + # done! So we use our frenemy, record_stream, to ensure the + # Tensor stays unmessed with until it's done getting used in the + # compute stream (s0 here). Note that the con here is we introduce + # non-deterministic (thus higher) memory usage, but this case + # should not happen often. + unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id] + if torch._C._storage_Use_Count(unpacked_tensor.untyped_storage()._cdata) > storage_refcount: + unpacked_tensor.record_stream(self.s0) + del self.bwd_tensor_stash[unpack_tensor_id] + else: + event = self.s0.record_event() + self.bwd_ev_stash[unpack_tensor_id] = event + + # if there are still things in the fwd_stash, get rid of them as we're in bwd now + for id in list(self.fwd_stash.keys()): + _, ev = self.fwd_stash[id] + self.s0.wait_event(ev) + del self.fwd_stash[id] + + # wait on prev node's events and del those + for id in prev_node_ids: + event = self.bwd_ev_stash[id] + self.s1.wait_event(event) + del self.bwd_tensor_stash[id] + + return outputs + + node.register_hook(hook) + + # clear tensor from tracking + del self.tracker[unpack_tensor_id] + return maybe_accelerator_tensor + + unpack_tensor = unpack_tensor_with_streams if self.use_streams else unpack_tensor_single_stream + super().__init__(pack_tensor, unpack_tensor) + + +class NoOpManager(saved_tensors_hooks): + """ + A `saved_tensors_hook` manager used to disable any other `saved_tensors_hook` manager applied before. This relies + on the behavior that only the most recently registered `saved_tensors_hook` will run. + + One example usage is to opt a local region of code out of activations offloading, which is usually applied globally + to best track state. + """ + + def __init__(self) -> None: + def noop(tensor): + return tensor + + super().__init__(noop, noop) + + +def get_act_offloading_ctx_manager( + model: nn.Module, + use_pin_memory: bool = True, + use_streams: bool = True, + min_offload_size: int = 1024, + max_fwd_stash_size: int = 5, + warn_if_no_head: bool = True, +) -> OffloadActivations: + """ + Returns the activation offloading context manager for the model. All but the last output Linear in every step will + be offloaded. + + If activation offloading is enabled, we return the OffloadActivations context manager. + If activation offloading is disabled, we return a NoOpManager context manager. + + Args: + model (`nn.Module`): + Model to wrap with the activation offloading context manager. + use_pin_memory (`bool`, *optional*, defaults to `True`): + Whether to offloaded Tensor will be placed in pinned memory on the CPU. Pinned memory allows the Tensor to + be moved back onto GPU more quickly but is a limited resource. + use_streams (`bool`, *optional*, defaults to `True`): + Whether to use streams for performance optimization where the communications get overlapped with the + computation. Requires a torch build after torch-2.5.0. + min_offload_size (`int`, *optional*, defaults to `1024`): + Minimum number of bytes a Tensor must be in order to qualify for offloading. If the tensor is too small, we + do not want to waste bandwidth and resources moving it to CPU and back. + max_fwd_stash_size (`int`, *optional*, defaults to `5`): + Maximum size of the forward stash, or the maximum number of consecutive activations to keep alive during + the forward pass. This number must be at least 1. Keeping alive more activations will potentially allow + more overlap between the communication and compute streams at the cost of increasing memory usage. Keeping + alive fewer activations will conserve memory, but may cause poor overlap between the streams, increasing + runtime. + warn_if_no_head (`bool`, *optional*, defaults to `True`): + Whether to warn if no output head is detected. If set to `False`, no warning will be raised if no output + head is detected. + + Returns: + `contextlib.ContextDecorator`: + Activation offloading context manager for the model. + """ + activations_handling_ctx = OffloadActivations( + use_pin_memory=use_pin_memory, + use_streams=use_streams, + min_offload_size=min_offload_size, + max_fwd_stash_size=max_fwd_stash_size, + ) + + # Below is our hack to disable offloading the last output Linear in every + # step, as the cost for offloading the activation and then soon after bringing + # it back is expensive. + output_head_detected = False + noop_ctx = NoOpManager() + + # Try to get the actual model if it's wrapped + unwrapped_model = model + if hasattr(unwrapped_model, "module"): + unwrapped_model = unwrapped_model.module + # check for PEFT models + if hasattr(unwrapped_model, "base_model") and hasattr(unwrapped_model, "peft_config"): + unwrapped_model = unwrapped_model.base_model + + # Check for different types of output heads + if hasattr(unwrapped_model, "output"): + if isinstance(unwrapped_model.output, nn.Module): + unwrapped_model.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) + unwrapped_model.output.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True) + output_head_detected = True + elif hasattr(unwrapped_model.output, "linear") and isinstance(unwrapped_model.output.linear, nn.Module): + unwrapped_model.output.linear.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) + unwrapped_model.output.linear.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True) + output_head_detected = True + + # Check for HuggingFace model output heads + elif hasattr(unwrapped_model, "lm_head"): + unwrapped_model.lm_head.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) + unwrapped_model.lm_head.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True) + output_head_detected = True + + # Check for decoder-based models + elif hasattr(unwrapped_model, "decoder"): + decoder = unwrapped_model.decoder + if hasattr(decoder, "output"): + decoder.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) + decoder.output.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True) + output_head_detected = True + # Some models have lm_head in the decoder + elif hasattr(decoder, "lm_head"): + decoder.lm_head.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) + decoder.lm_head.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True) + output_head_detected = True + + # Check for transformer models with final layer norm + elif hasattr(unwrapped_model, "final_layer_norm") or hasattr(unwrapped_model, "ln_f"): + final_norm = getattr(unwrapped_model, "final_layer_norm", None) or unwrapped_model.ln_f + final_norm.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) + final_norm.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True) + output_head_detected = True + + # Check for models with head module + elif hasattr(unwrapped_model, "head") and isinstance(unwrapped_model.head, nn.Module): + unwrapped_model.head.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) + unwrapped_model.head.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True) + output_head_detected = True + + if not output_head_detected and warn_if_no_head: + warnings.warn( + "During activation offloading, no output head was detected. If your model has an output head, it will be " + "offloaded. This usually greatly slows training, given the large vocabulary size. To change this " + "behavior, set your output head as model.output and make it an nn.Module. You can disable this warning by " + "passing `warn_if_no_head=False`." + ) + + # Disable offloading for any Liger modules + for name, module in unwrapped_model.named_modules(): + if "liger" in name.lower(): + module.register_forward_pre_hook(lambda *args: noop_ctx.__enter__()) + module.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True) + + return activations_handling_ctx diff --git a/trl/models/auxiliary_modules.py b/trl/models/auxiliary_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..2a119363f36b49a30a97404cfa17d812934d5ff6 --- /dev/null +++ b/trl/models/auxiliary_modules.py @@ -0,0 +1,96 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import torch +import torch.nn as nn +import torchvision +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import EntryNotFoundError +from transformers import CLIPModel, is_torch_npu_available, is_torch_xpu_available + + +class MLP(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(768, 1024), + nn.Dropout(0.2), + nn.Linear(1024, 128), + nn.Dropout(0.2), + nn.Linear(128, 64), + nn.Dropout(0.1), + nn.Linear(64, 16), + nn.Linear(16, 1), + ) + + def forward(self, embed): + return self.layers(embed) + + +class AestheticScorer(torch.nn.Module): + """ + This model attempts to predict the aesthetic score of an image. The aesthetic score + is a numerical approximation of how much a specific image is liked by humans on average. + This is from https://github.com/christophschuhmann/improved-aesthetic-predictor + """ + + def __init__(self, *, dtype, model_id, model_filename): + super().__init__() + self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14") + self.normalize = torchvision.transforms.Normalize( + mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711] + ) + self.target_size = 224 + self.mlp = MLP() + try: + cached_path = hf_hub_download(model_id, model_filename) + except EntryNotFoundError: + cached_path = os.path.join(model_id, model_filename) + state_dict = torch.load(cached_path, map_location=torch.device("cpu"), weights_only=True) + self.mlp.load_state_dict(state_dict) + self.dtype = dtype + self.eval() + + def __call__(self, images): + device = next(self.parameters()).device + images = torchvision.transforms.Resize(self.target_size)(images) + images = self.normalize(images).to(self.dtype).to(device) + embed = self.clip.get_image_features(pixel_values=images) + # normalize embedding + embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True) + reward = self.mlp(embed).squeeze(1) + return reward + + +def aesthetic_scorer(hub_model_id, model_filename): + scorer = AestheticScorer( + model_id=hub_model_id, + model_filename=model_filename, + dtype=torch.float32, + ) + if is_torch_npu_available(): + scorer = scorer.npu() + elif is_torch_xpu_available(): + scorer = scorer.xpu() + else: + scorer = scorer.cuda() + + def _fn(images, prompts, metadata): + images = (images).clamp(0, 1) + scores = scorer(images) + return scores, {} + + return _fn diff --git a/trl/models/modeling_base.py b/trl/models/modeling_base.py new file mode 100644 index 0000000000000000000000000000000000000000..07c0e3f26dd52190f97a15d6c4a9216ceb33cfa6 --- /dev/null +++ b/trl/models/modeling_base.py @@ -0,0 +1,731 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +from copy import deepcopy +from typing import Optional + +import torch +import torch.nn as nn +from accelerate import PartialState +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import ( + EntryNotFoundError, + HFValidationError, + LocalEntryNotFoundError, + RepositoryNotFoundError, +) +from safetensors.torch import load_file as safe_load_file +from transformers import GenerationMixin, PreTrainedModel, is_torch_npu_available, is_torch_xpu_available +from transformers.utils import is_peft_available + + +if is_peft_available(): + from peft import ( + PeftConfig, + PeftModel, + PeftModelForCausalLM, + PeftModelForSeq2SeqLM, + PromptLearningConfig, + get_peft_model, + prepare_model_for_kbit_training, + ) + + +from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled + + +LAYER_PATTERNS = [ + "transformer.h.{layer}", + "model.decoder.layers.{layer}", + "gpt_neox.layers.{layer}", + "model.layers.{layer}", +] + + +class PreTrainedModelWrapper(nn.Module): + r""" + A wrapper class around a (`transformers.PreTrainedModel`) to be compatible with the + (`~transformers.PreTrained`) class in order to keep some attributes and methods of the + (`~transformers.PreTrainedModel`) class. + + Attributes: + pretrained_model (`transformers.PreTrainedModel`): + The model to be wrapped. + parent_class (`transformers.PreTrainedModel`): + The parent class of the model to be wrapped. + supported_args (`list`): + The list of arguments that are supported by the wrapper class. + """ + + transformers_parent_class = None + supported_args = None + supported_modules = ("v_head",) + supported_rm_modules = ("score",) + supported_pretrained_model_architectures = ( + (PreTrainedModel) + if not is_peft_available() + else (PreTrainedModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM) + ) + + def __init__( + self, pretrained_model=None, score_module=None, supports_rm_adapter=False, rm_adapter_name=None, **kwargs + ): + super().__init__() + self.pretrained_model = pretrained_model + + self.config = pretrained_model.config + self.prepare_inputs_for_generation = pretrained_model.prepare_inputs_for_generation + self.is_loaded_in_8bit = getattr(pretrained_model, "is_loaded_in_8bit", False) + self.is_loaded_in_4bit = getattr(pretrained_model, "is_loaded_in_4bit", False) + self.is_sequential_parallel = False + + if hasattr(pretrained_model, "gradient_checkpointing_disable"): + self.gradient_checkpointing_disable = pretrained_model.gradient_checkpointing_disable + + if hasattr(pretrained_model, "gradient_checkpointing_enable"): + self.gradient_checkpointing_enable = pretrained_model.gradient_checkpointing_enable + + if hasattr(pretrained_model, "enable_input_require_grads"): + self.enable_input_require_grads = pretrained_model.enable_input_require_grads + + self.supports_rm_adapter = supports_rm_adapter + self.rm_adapter_name = rm_adapter_name + self.policy_adapter_name = "default" + if score_module is not None: + self.score = score_module + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + r""" + Instantiates a new model from a pretrained model from `transformers`. The + pretrained model is loaded using the `from_pretrained` method of the + `transformers.PreTrainedModel` class. The arguments that are specific to the + `transformers.PreTrainedModel` class are passed along this method and filtered + out from the `kwargs` argument. + + Args: + pretrained_model_name_or_path (`str` or `transformers.PreTrainedModel`): + The path to the pretrained model or its name. + *model_args (`list`, *optional*)): + Additional positional arguments passed along to the underlying model's + `from_pretrained` method. + **kwargs (`dict`, *optional*): + Additional keyword arguments passed along to the underlying model's + `from_pretrained` method. We also pre-process the kwargs to extract + the arguments that are specific to the `transformers.PreTrainedModel` + class and the arguments that are specific to trl models. The kwargs + also support `prepare_model_for_kbit_training` arguments from + `peft` library. + """ + if kwargs is not None: + peft_config = kwargs.pop("peft_config", None) + reward_adapter = kwargs.pop("reward_adapter", None) + reward_adapter_name = kwargs.pop("reward_adapter_name", "reward_adapter") + is_trainable = kwargs.pop("is_trainable", False) + trl_model_args, pretrained_kwargs, peft_quantization_kwargs = cls._split_kwargs(kwargs) + token = pretrained_kwargs.get("token", None) + else: + peft_config = None + is_trainable = False + trl_model_args = {} + pretrained_kwargs = {} + peft_quantization_kwargs = {} + token = None + + if reward_adapter is not None and not isinstance(reward_adapter, str): + raise ValueError( + "The `reward_adapter` argument should be a string representing the name of local path or the Hub id to the Reward Modeling adapter." + ) + + is_peft_model = False + + current_device = cls._get_current_device() + if isinstance(pretrained_model_name_or_path, str): + is_loaded_in_8bit = pretrained_kwargs["load_in_8bit"] if "load_in_8bit" in pretrained_kwargs else False + is_loaded_in_4bit = pretrained_kwargs["load_in_4bit"] if "load_in_4bit" in pretrained_kwargs else False + else: + is_loaded_in_8bit = getattr(pretrained_model_name_or_path, "is_loaded_in_8bit", False) + is_loaded_in_4bit = getattr(pretrained_model_name_or_path, "is_loaded_in_4bit", False) + + if (is_loaded_in_8bit or is_loaded_in_4bit) and "device_map" not in pretrained_kwargs: + # warn users + logging.warning( + "The `device_map` argument is not provided. We will override the device_map argument." + " to set the entire" + " model on the current device. If you want to set the model on multiple devices, please provide" + " a custom `device_map` argument." + ) + pretrained_kwargs["device_map"] = {"": current_device} + + if is_peft_available() and peft_config is not None and not isinstance(peft_config, PeftConfig): + raise ValueError("The `peft_config` argument should be an instance of `peft.PeftConfig` class.") + + # First, load the pre-trained model using the parent-class + # either `AutoModelForCausalLM` or `AutoModelForSeq2SeqLM` + if isinstance(pretrained_model_name_or_path, str): + if is_peft_available(): + try: + # If there is a trained peft adapter in the hub, load its config. + remote_adapter_config = hf_hub_download( + pretrained_model_name_or_path, + "adapter_config.json", + token=token, + ) + except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError): + remote_adapter_config = None + else: + remote_adapter_config = None + + local_adapter_present = os.path.exists(os.path.join(pretrained_model_name_or_path, "adapter_config.json")) + + if (local_adapter_present or remote_adapter_config is not None) and is_peft_available(): + if peft_config is not None: + logging.warning( + "`peft_config` argument ignored since a peft config file was found in " + f"{pretrained_model_name_or_path}" + ) + + # Load the trained peft adapter config + if local_adapter_present: + trained_adapter_config = PeftConfig.from_pretrained(pretrained_model_name_or_path) + else: + remote_adapter_dir = os.path.dirname(remote_adapter_config) + trained_adapter_config = PeftConfig.from_pretrained(remote_adapter_dir) + + # Load the pretrained base model + pretrained_model = cls.transformers_parent_class.from_pretrained( + trained_adapter_config.base_model_name_or_path, *model_args, **pretrained_kwargs + ) + + # Wrap the pretrained model with the trained peft adapter + pretrained_model = PeftModel.from_pretrained( + pretrained_model, pretrained_model_name_or_path, is_trainable=is_trainable, token=token + ) + logging.info("Trained peft adapter loaded") + else: + pretrained_model = cls.transformers_parent_class.from_pretrained( + pretrained_model_name_or_path, *model_args, **pretrained_kwargs + ) + + if peft_config is not None: + # Initialize a new peft adapter with the given config + if is_loaded_in_8bit or is_loaded_in_4bit: + pretrained_model = prepare_model_for_kbit_training( + pretrained_model, + **peft_quantization_kwargs, + ) + pretrained_model = get_peft_model(pretrained_model, peft_config) + logging.info("peft adapter initialised") + + elif isinstance(pretrained_model_name_or_path, cls.supported_pretrained_model_architectures): + pretrained_model = pretrained_model_name_or_path + + if peft_config is not None and isinstance(pretrained_model, PreTrainedModel): + # Initialize a new peft adapter with the given config + if is_loaded_in_8bit or is_loaded_in_4bit: + pretrained_model = prepare_model_for_kbit_training( + pretrained_model, + **peft_quantization_kwargs, + ) + pretrained_model = get_peft_model(pretrained_model, peft_config) + logging.info("peft adapter initialised") + else: + raise ValueError( + "pretrained_model_name_or_path should be a string or a PreTrainedModel, " + f"but is {type(pretrained_model_name_or_path)}" + ) + + if is_peft_available(): + if isinstance(pretrained_model, PeftModel): + is_peft_model = True + # for backward compatibility + if hasattr(pretrained_model, "active_peft_config") and isinstance( + pretrained_model.active_peft_config, PromptLearningConfig + ): + raise ValueError("PromptLearningConfig is not supported for PPO training.") + + # Add reward modeling adapter if specified + if not is_peft_model and reward_adapter is not None: + raise ValueError("reward_adapter can only be used with a PeftModel. ") + elif is_peft_model and reward_adapter is not None: + score_module = cls.add_and_load_reward_modeling_adapter( + pretrained_model, reward_adapter, reward_adapter_name, token=token + ) + multi_adapter_args = { + "score_module": score_module, + "supports_rm_adapter": True, + "rm_adapter_name": reward_adapter_name, + } + else: + multi_adapter_args = {"supports_rm_adapter": False} + + # Then, create the full model by instantiating the wrapper class + model = cls(pretrained_model, **multi_adapter_args, **trl_model_args) + + # if resume_training, load the state_dict again - this is ok since the + # state_dict is removed from the model after loading it. + is_resuming_training = True + if isinstance(pretrained_model_name_or_path, str): + safe_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors") + filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin") + + sharded_index_filename = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin.index.json") + safe_sharded_index_filename = os.path.join(pretrained_model_name_or_path, "model.safetensors.index.json") + is_sharded = False + use_safe = os.path.exists(safe_filename) + + if not (os.path.exists(filename) or os.path.exists(safe_filename)): + # Try with `pytorch_model.bin` + filename, files_to_download, is_sharded, is_resuming_training = cls._get_checkpoint_from_hub( + pretrained_model, + pretrained_model_name_or_path, + sharded_index_filename, + token=token, + ) + # Try with safetensors + if filename is None and files_to_download is None: + safe_filename, files_to_download, is_sharded, is_resuming_training = cls._get_checkpoint_from_hub( + pretrained_model, + pretrained_model_name_or_path, + safe_sharded_index_filename, + token=token, + model_name="model.safetensors", + model_index_name="model.safetensors.index.json", + ) + use_safe = True + else: + use_safe = False + + loading_func = safe_load_file if use_safe else torch.load + load_kwargs = {} if use_safe else {"map_location": "cpu", "weights_only": True} + + if is_resuming_training: + if is_sharded: + # download each file and add it to the state_dict + state_dict = {} + + for shard_file in files_to_download: + filename = hf_hub_download( + pretrained_model_name_or_path, + shard_file, + token=token, + ) + state_dict.update(loading_func(filename, **load_kwargs)) + else: + state_dict = loading_func(filename if not use_safe else safe_filename, **load_kwargs) + + else: + state_dict = pretrained_model_name_or_path.state_dict() + + model.is_peft_model = is_peft_model + model.current_device = current_device + + if is_resuming_training: + model.post_init(state_dict=state_dict) + + return model + + @classmethod + def _get_checkpoint_from_hub( + cls, + pretrained_model, + pretrained_model_name_or_path, + index_filename, + token=None, + model_name="pytorch_model.bin", + model_index_name="pytorch_model.bin.index.json", + ): + files_to_download = None + filename = None + is_resuming_training = True + is_sharded = False + + try: + filename = hf_hub_download( + pretrained_model_name_or_path, + model_name, + token=token, + ) + # sharded + except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError): + if os.path.exists(index_filename): + index_file_name = index_filename + else: + try: + index_file_name = hf_hub_download( + pretrained_model_name_or_path, + model_index_name, + token=token, + ) + except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError, RepositoryNotFoundError): + # not continue training, do not have v_head weight + is_resuming_training = False + logging.warning( + f"A {type(pretrained_model)} model is loaded from '{pretrained_model_name_or_path}', " + f"and no v_head weight is found. This IS expected if you are not resuming PPO training." + ) + # load json + if is_resuming_training: + with open(index_file_name) as f: + index = json.load(f) + # check filename with `v_head` or any known extra module: + files_to_download = set() + for k, v in index["weight_map"].items(): + if any(module in k for module in cls.supported_modules): + files_to_download.add(v) + is_sharded = True + + return filename, files_to_download, is_sharded, is_resuming_training + + @classmethod + def _get_current_device(cls): + r""" + Get the current device. For GPU & XPU, we return the local process index using the `accelerate.PartialState` + object to handle corner cases when running scripts in distributed environments. + + Returns: + current_device (`Union[int, str]`): + The current device. + """ + state = PartialState() + if torch.cuda.is_available() or is_torch_xpu_available(): + return state.local_process_index + elif is_torch_npu_available(): + return f"npu:{state.local_process_index}" + else: + return "cpu" + + @classmethod + def _split_kwargs(cls, kwargs): + """ + Separate the kwargs from the arguments that we support inside + `supported_args` and the ones that we don't. + """ + check_peft_kwargs = False + + if is_peft_available(): + from peft import prepare_model_for_kbit_training + + check_peft_kwargs = True + + supported_kwargs = {} + unsupported_kwargs = {} + peft_kwargs = {} + + for key, value in kwargs.items(): + if key in cls.supported_args: + supported_kwargs[key] = value + else: + unsupported_kwargs[key] = value + + if check_peft_kwargs: + if key in prepare_model_for_kbit_training.__code__.co_varnames: + peft_kwargs[key] = value + if key in unsupported_kwargs: + unsupported_kwargs.pop(key) + + return supported_kwargs, unsupported_kwargs, peft_kwargs + + @classmethod + def add_and_load_reward_modeling_adapter( + cls, pretrained_model, adapter_model_id, adapter_name="reward_model_adapter", token=None + ): + r""" + Add and load a reward modeling adapter. This method can only be used if the + model is a `PeftModel` and if you have initialized the model with the `reward_modeling_adapter_id` + argument, pointing to the id of the reward modeling adapter. The latest needs also to contain the + score head in order to produce the reward. + """ + pretrained_model.load_adapter(adapter_model_id, adapter_name, is_trainable=False) + pretrained_model.train() + + filename = os.path.join(adapter_model_id, "adapter_model.bin") + safe_loading = False + if not os.path.exists(filename): + try: + local_filename = hf_hub_download( + adapter_model_id, + "adapter_model.bin", + token=token, + ) + except Exception: + filename = os.path.join(adapter_model_id, "adapter_model.safetensors") + safe_loading = True + if not os.path.exists(filename): + try: + local_filename = hf_hub_download( + adapter_model_id, + "adapter_model.safetensors", + token=token, + ) + except Exception as exc: + raise ValueError( + "Could not find adapter model in the Hub, make sure you have the correct adapter model id." + ) from exc + else: + local_filename = filename + else: + local_filename = filename + + loading_func = safe_load_file if safe_loading else torch.load + load_kwargs = {} if safe_loading else {"map_location": "cpu", "weights_only": True} + + adapter_state_dict = loading_func(local_filename, **load_kwargs) + + for score_name_candidate in cls.supported_rm_modules: + if any(score_name_candidate in name for name in adapter_state_dict.keys()): + score_name = score_name_candidate + # we have found the correct head name and can break + break + + score_dict = {} + + for name, param in adapter_state_dict.items(): + if score_name in name: + key_name = ".".join(name.split(".")[-1:]) + score_dict[key_name] = param.to(cls._get_current_device()) + + num_labels, hidden_dim = score_dict["weight"].shape + has_bias = any("bias" in name for name in adapter_state_dict.keys()) + + score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to( + device=cls._get_current_device(), + dtype=pretrained_model.dtype, + ) + score.load_state_dict(score_dict) + for param in score.parameters(): + param.requires_grad = False + + return score + + def push_to_hub(self, *args, **kwargs): + r""" + Push the pretrained model to the hub. This method is a wrapper around + `transformers.PreTrainedModel.push_to_hub`. Please refer to the documentation + of `transformers.PreTrainedModel.push_to_hub` for more information. + + Args: + *args (`list`, *optional*): + Positional arguments passed along to the underlying model's + `push_to_hub` method. + **kwargs (`dict`, *optional*): + Keyword arguments passed along to the underlying model's + `push_to_hub` method. + """ + raise NotImplementedError + + def save_pretrained(self, *args, **kwargs): + r""" + Save the pretrained model to a directory. This method is a wrapper around + `transformers.PreTrainedModel.save_pretrained`. Please refer to the documentation + of `transformers.PreTrainedModel.save_pretrained` for more information. + + Args: + *args (`list`, *optional*): + Positional arguments passed along to the underlying model's + `save_pretrained` method. + **kwargs (`dict`, *optional*): + Keyword arguments passed along to the underlying model's + `save_pretrained` method. + """ + state_dict = kwargs.get("state_dict") + if state_dict is None: + state_dict = self.state_dict() + kwargs["state_dict"] = state_dict + + # if it is a peft model only save the `v_head` state_dict and + # pop the `state_dict` from the kwargs to avoid slient bugs with `peft` + if self.is_peft_model: + save_path = args[0] + save_path = os.path.join(save_path, "pytorch_model.bin") + torch.save(state_dict, save_path) + _ = kwargs.pop("state_dict", None) + + return self.pretrained_model.save_pretrained(*args, **kwargs) + + def state_dict(self, *args, **kwargs): + r""" + Return the state_dict of the pretrained model. + """ + raise NotImplementedError + + def post_init(self, *args, **kwargs): + r""" + Post initialization method. This method is called after the model is + instantiated and loaded from a checkpoint. It can be used to perform + additional operations such as loading the state_dict. + """ + raise NotImplementedError + + def compute_reward_score(self, input_ids, attention_mask=None, **kwargs): + r""" + Computes the reward score for a given input. The method has first to enable the adapter + and then compute the reward score. After that the model disables the reward modeling + adapter and enables the default ppo adapter again. + """ + if not self.supports_rm_adapter: + raise ValueError("This model does not support reward modeling adapter.") + + # enable rm adapter + self.pretrained_model.set_adapter(self.rm_adapter_name) + self.pretrained_model.eval() + + with torch.no_grad(): + base_model_output = self.pretrained_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + return_dict=True, + **kwargs, + ) + + last_hidden_states = base_model_output.hidden_states[-1] + scores = self.score(last_hidden_states) + + self.pretrained_model.set_adapter(self.policy_adapter_name) + self.pretrained_model.eval() + + return scores + + +def create_reference_model( + model: PreTrainedModelWrapper, num_shared_layers: Optional[int] = None, pattern: Optional[str] = None +) -> PreTrainedModelWrapper: + """ + Creates a static reference copy of a model. Note that model will be in `.eval()` mode. + + Args: + model (`PreTrainedModelWrapper`): The model to be copied. + num_shared_layers (`int`, *optional*): The number of initial layers that are shared between both models and kept frozen. + pattern (`str`, *optional*): The shared layers are selected with a string pattern + (e.g. "transformer.h.{layer}" for GPT2) and if a custom pattern is necessary it can be passed here. + + Returns: + `PreTrainedModelWrapper` + """ + if is_deepspeed_zero3_enabled(): + raise ValueError( + "DeepSpeed ZeRO-3 is enabled and is not compatible with `create_reference_model()`. Please instantiate your reference model directly with `AutoModelForCausalLM.from_pretrained()`." + ) + + parameter_names = [n for n, _ in model.named_parameters()] + ref_model = deepcopy(model) + + # if no layers are shared, return copy of model + if num_shared_layers is None: + for param_name in parameter_names: + param = ref_model.get_parameter(param_name) + param.requires_grad = False + return ref_model.eval() + + # identify layer name pattern + if pattern is not None: + pattern = pattern.format(layer=num_shared_layers) + else: + for pattern_candidate in LAYER_PATTERNS: + pattern_candidate = pattern_candidate.format(layer=num_shared_layers) + if any(pattern_candidate in name for name in parameter_names): + pattern = pattern_candidate + break + + if pattern is None: + raise ValueError("Layer pattern could not be matched.") + + # divide parameters in shared and unshared parameter lists + shared_param_list = [] + unshared_param_list = [] + + shared_parameter = True + for name, _param in model.named_parameters(): + if pattern in name: + shared_parameter = False + if shared_parameter: + shared_param_list.append(name) + else: + unshared_param_list.append(name) + + # create reference of the original parameter if they are shared + for param_name in shared_param_list: + param = model.get_parameter(param_name) + param.requires_grad = False + + _ref_param = ref_model.get_parameter(param_name) + + # for all other parameters just make sure they don't use gradients + for param_name in unshared_param_list: + param = ref_model.get_parameter(param_name) + param.requires_grad = False + + if pattern is not None and len(unshared_param_list) == 0: + logging.warning("Pattern passed or found, but no layers matched in the model. Check for a typo.") + + return ref_model.eval() + + +class GeometricMixtureWrapper(GenerationMixin): + r""" + Geometric Mixture generation wrapper that samples from the logits of two model's geometric mixture. + + Args: + model (`PreTrainedModel`): The model to be wrapped. + ref_model (`PreTrainedModel`): The reference model. + generation_config (`GenerationConfig`): The generation config. + mixture_coef (`float`, *optional* - default: 0.5): The mixture coefficient. + """ + + main_input_name = "input_ids" + _supports_cache_class = False + _supports_static_cache = False + + def __init__(self, model, ref_model, generation_config, mixture_coef=0.5, device=None): + super().__init__() + + self.model = model + self.config = model.config + self.ref_model = ref_model + self.generation_config = generation_config + self.mixture_coef = mixture_coef + self.device = device + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + @torch.inference_mode() + def forward(self, *args, **kwargs): + model_outputs = self.model(*args, **kwargs) + model_logits = model_outputs.logits + ref_model_logits = self.ref_model(*args, **kwargs).logits + + model_outputs.logits = torch.nn.functional.log_softmax( + self.mixture_coef * ref_model_logits + (1 - self.mixture_coef) * model_logits, dim=-1 + ) + + return model_outputs + + def prepare_inputs_for_generation(self, *args, **kwargs): + # turn off cache in the generation config + kwargs["use_cache"] = False + model_inputs = self.model.prepare_inputs_for_generation(*args, **kwargs) + _ = self.ref_model.prepare_inputs_for_generation(*args, **kwargs) + + return model_inputs + + def _validate_model_class(self): + self.model._validate_model_class() + + def _validate_model_kwargs(self, model_kwargs): + return self.model._validate_model_kwargs(model_kwargs) diff --git a/trl/models/modeling_sd_base.py b/trl/models/modeling_sd_base.py new file mode 100644 index 0000000000000000000000000000000000000000..82a951935498e15eef747e90b6c0d92c922b1a65 --- /dev/null +++ b/trl/models/modeling_sd_base.py @@ -0,0 +1,911 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import os +import random +import warnings +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import numpy as np +import torch +import torch.utils.checkpoint as checkpoint +from diffusers import DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg +from transformers.utils import is_peft_available + +from ..core import randn_tensor +from .sd_utils import convert_state_dict_to_diffusers + + +if is_peft_available(): + from peft import LoraConfig + from peft.utils import get_peft_model_state_dict + + +@dataclass +class DDPOPipelineOutput: + """ + Output class for the diffusers pipeline to be finetuned with the DDPO trainer + + Args: + images (`torch.Tensor`): + The generated images. + latents (`list[torch.Tensor]`): + The latents used to generate the images. + log_probs (`list[torch.Tensor]`): + The log probabilities of the latents. + + """ + + images: torch.Tensor + latents: torch.Tensor + log_probs: torch.Tensor + + +@dataclass +class DDPOSchedulerOutput: + """ + Output class for the diffusers scheduler to be finetuned with the DDPO trainer + + Args: + latents (`torch.Tensor`): + Predicted sample at the previous timestep. Shape: `(batch_size, num_channels, height, width)` + log_probs (`torch.Tensor`): + Log probability of the above mentioned sample. Shape: `(batch_size)` + """ + + latents: torch.Tensor + log_probs: torch.Tensor + + +class DDPOStableDiffusionPipeline: + """ + Main class for the diffusers pipeline to be finetuned with the DDPO trainer + """ + + def __call__(self, *args, **kwargs) -> DDPOPipelineOutput: + raise NotImplementedError + + def scheduler_step(self, *args, **kwargs) -> DDPOSchedulerOutput: + raise NotImplementedError + + @property + def unet(self): + """ + Returns the 2d U-Net model used for diffusion. + """ + raise NotImplementedError + + @property + def vae(self): + """ + Returns the Variational Autoencoder model used from mapping images to and from the latent space + """ + raise NotImplementedError + + @property + def tokenizer(self): + """ + Returns the tokenizer used for tokenizing text inputs + """ + raise NotImplementedError + + @property + def scheduler(self): + """ + Returns the scheduler associated with the pipeline used for the diffusion process + """ + raise NotImplementedError + + @property + def text_encoder(self): + """ + Returns the text encoder used for encoding text inputs + """ + raise NotImplementedError + + @property + def autocast(self): + """ + Returns the autocast context manager + """ + raise NotImplementedError + + def set_progress_bar_config(self, *args, **kwargs): + """ + Sets the progress bar config for the pipeline + """ + raise NotImplementedError + + def save_pretrained(self, *args, **kwargs): + """ + Saves all of the model weights + """ + raise NotImplementedError + + def get_trainable_layers(self, *args, **kwargs): + """ + Returns the trainable parameters of the pipeline + """ + raise NotImplementedError + + def save_checkpoint(self, *args, **kwargs): + """ + Light wrapper around accelerate's register_save_state_pre_hook which is run before saving state + """ + raise NotImplementedError + + def load_checkpoint(self, *args, **kwargs): + """ + Light wrapper around accelerate's register_lad_state_pre_hook which is run before loading state + """ + raise NotImplementedError + + +def _left_broadcast(input_tensor, shape): + """ + As opposed to the default direction of broadcasting (right to left), this function broadcasts + from left to right + Args: + input_tensor (`torch.FloatTensor`): is the tensor to broadcast + shape (`tuple[int]`): is the shape to broadcast to + """ + input_ndim = input_tensor.ndim + if input_ndim > len(shape): + raise ValueError( + "The number of dimensions of the tensor to broadcast cannot be greater than the length of the shape to broadcast to" + ) + return input_tensor.reshape(input_tensor.shape + (1,) * (len(shape) - input_ndim)).broadcast_to(shape) + + +def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device) + alpha_prod_t_prev = torch.where( + prev_timestep.cpu() >= 0, + self.alphas_cumprod.gather(0, prev_timestep.cpu()), + self.final_alpha_cumprod, + ).to(timestep.device) + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + +def scheduler_step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + prev_sample: Optional[torch.FloatTensor] = None, +) -> DDPOSchedulerOutput: + """ + + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + eta (`float`): weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped + predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when + `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would + coincide with the one provided as input and `use_clipped_model_output` will have not effect. + generator: random number generator. + variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we + can directly provide the noise for the variance itself. This is useful for methods such as + CycleDiffusion. (https://huggingface.co/papers/2210.05559) + + Returns: + `DDPOSchedulerOutput`: the predicted sample at the previous timestep and the log probability of the sample + """ + + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://huggingface.co/papers/2010.02502 + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + # to prevent OOB on gather + prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1) + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu()) + alpha_prod_t_prev = torch.where( + prev_timestep.cpu() >= 0, + self.alphas_cumprod.gather(0, prev_timestep.cpu()), + self.final_alpha_cumprod, + ) + alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device) + alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(sample.device) + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://huggingface.co/papers/2010.02502 + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 4. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = _get_variance(self, timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + std_dev_t = _left_broadcast(std_dev_t, sample.shape).to(sample.device) + + if use_clipped_model_output: + # the pred_epsilon is always re-derived from the clipped x_0 in Glide + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://huggingface.co/papers/2010.02502 + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon + + # 7. compute x_t without "random noise" of formula (12) from https://huggingface.co/papers/2010.02502 + prev_sample_mean = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if prev_sample is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and prev_sample. Please make sure that either `generator` or" + " `prev_sample` stays `None`." + ) + + if prev_sample is None: + variance_noise = randn_tensor( + model_output.shape, + generator=generator, + device=model_output.device, + dtype=model_output.dtype, + ) + prev_sample = prev_sample_mean + std_dev_t * variance_noise + + # log prob of prev_sample given prev_sample_mean and std_dev_t + log_prob = ( + -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2)) + - torch.log(std_dev_t) + - torch.log(torch.sqrt(2 * torch.as_tensor(np.pi))) + ) + # mean along all but batch dimension + log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim))) + + return DDPOSchedulerOutput(prev_sample.type(sample.dtype), log_prob) + + +# 1. The output type for call is different as the logprobs are now returned +# 2. An extra method called `scheduler_step` is added which is used to constraint the scheduler output +@torch.no_grad() +def pipeline_step( + self, + prompt: Optional[Union[str, list[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, list[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[dict[str, Any]] = None, + guidance_rescale: float = 0.0, +): + r""" + Function invoked when calling the pipeline for generation. Args: prompt (`str` or `list[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + + Examples: + + Returns: + `DDPOPipelineOutput`: The generated image, the predicted latents used to generate the image and the associated log probabilities + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + all_latents = [latents] + all_log_probs = [] + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = scheduler_step(self.scheduler, noise_pred, t, latents, eta) + latents = scheduler_output.latents + log_prob = scheduler_output.log_probs + + all_latents.append(latents) + all_log_probs.append(log_prob) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + return DDPOPipelineOutput(image, all_latents, all_log_probs) + + +def pipeline_step_with_grad( + pipeline, + prompt: Optional[Union[str, list[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + truncated_backprop: bool = True, + truncated_backprop_rand: bool = True, + gradient_checkpoint: bool = True, + truncated_backprop_timestep: int = 49, + truncated_rand_backprop_minmax: tuple = (0, 50), + negative_prompt: Optional[Union[str, list[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[dict[str, Any]] = None, + guidance_rescale: float = 0.0, +): + r""" + Function to get RGB image with gradients attached to the model weights. + + Args: + prompt (`str` or `list[str]`, *optional*, defaults to `None`): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead. + height (`int`, *optional*, defaults to `pipeline.unet.config.sample_size * pipeline.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `pipeline.unet.config.sample_size * pipeline.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to `7.5`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://huggingface.co/papers/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + truncated_backprop (`bool`, *optional*, defaults to True): + Truncated Backpropation to fixed timesteps, helps prevent collapse during diffusion reward training as shown in AlignProp (https://huggingface.co/papers/2310.03739). + truncated_backprop_rand (`bool`, *optional*, defaults to True): + Truncated Randomized Backpropation randomizes truncation to different diffusion timesteps, this helps prevent collapse during diffusion reward training as shown in AlignProp (https://huggingface.co/papers/2310.03739). + Enabling truncated_backprop_rand allows adapting earlier timesteps in diffusion while not resulting in a collapse. + gradient_checkpoint (`bool`, *optional*, defaults to True): + Adds gradient checkpointing to Unet forward pass. Reduces GPU memory consumption while slightly increasing the training time. + truncated_backprop_timestep (`int`, *optional*, defaults to 49): + Absolute timestep to which the gradients are being backpropagated. Higher number reduces the memory usage and reduces the chances of collapse. + While a lower value, allows more semantic changes in the diffusion generations, as the earlier diffusion timesteps are getting updated. + However it also increases the chances of collapse. + truncated_rand_backprop_minmax (`Tuple`, *optional*, defaults to (0,50)): + Range for randomized backprop. Here the value at 0 index indicates the earlier diffusion timestep to update (closer to noise), while the value + at index 1 indicates the later diffusion timestep to update. + negative_prompt (`str` or `list[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `list[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `pipeline.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + + Examples: + + Returns: + `DDPOPipelineOutput`: The generated image, the predicted latents used to generate the image and the associated log probabilities + """ + # 0. Default height and width to unet + height = height or pipeline.unet.config.sample_size * pipeline.vae_scale_factor + width = width or pipeline.unet.config.sample_size * pipeline.vae_scale_factor + + with torch.no_grad(): + # 1. Check inputs. Raise error if not correct + pipeline.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = pipeline._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = pipeline._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + pipeline.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = pipeline.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = pipeline.unet.config.in_channels + latents = pipeline.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * pipeline.scheduler.order + all_latents = [latents] + all_log_probs = [] + with pipeline.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + if gradient_checkpoint: + noise_pred = checkpoint.checkpoint( + pipeline.unet, + latent_model_input, + t, + prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + use_reentrant=False, + )[0] + else: + noise_pred = pipeline.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # truncating backpropagation is critical for preventing overoptimization (https://huggingface.co/papers/2304.05977). + if truncated_backprop: + # Randomized truncation randomizes the truncation process (https://huggingface.co/papers/2310.03739) + # the range of truncation is defined by truncated_rand_backprop_minmax + # Setting truncated_rand_backprop_minmax[0] to be low will allow the model to update earlier timesteps in the diffusion chain, while setitng it high will reduce the memory usage. + if truncated_backprop_rand: + rand_timestep = random.randint( + truncated_rand_backprop_minmax[0], truncated_rand_backprop_minmax[1] + ) + if i < rand_timestep: + noise_pred = noise_pred.detach() + else: + # fixed truncation process + if i < truncated_backprop_timestep: + noise_pred = noise_pred.detach() + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://huggingface.co/papers/2305.08891 + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = scheduler_step(pipeline.scheduler, noise_pred, t, latents, eta) + latents = scheduler_output.latents + log_prob = scheduler_output.log_probs + + all_latents.append(latents) + all_log_probs.append(log_prob) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = pipeline.vae.decode(latents / pipeline.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = pipeline.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = pipeline.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(pipeline, "final_offload_hook") and pipeline.final_offload_hook is not None: + pipeline.final_offload_hook.offload() + + return DDPOPipelineOutput(image, all_latents, all_log_probs) + + +class DefaultDDPOStableDiffusionPipeline(DDPOStableDiffusionPipeline): + def __init__(self, pretrained_model_name: str, *, pretrained_model_revision: str = "main", use_lora: bool = True): + self.sd_pipeline = StableDiffusionPipeline.from_pretrained( + pretrained_model_name, revision=pretrained_model_revision + ) + + self.use_lora = use_lora + self.pretrained_model = pretrained_model_name + self.pretrained_revision = pretrained_model_revision + + try: + self.sd_pipeline.load_lora_weights( + pretrained_model_name, + weight_name="pytorch_lora_weights.safetensors", + revision=pretrained_model_revision, + ) + self.use_lora = True + except OSError: + if use_lora: + warnings.warn( + "Trying to load LoRA weights but no LoRA weights found. Set `use_lora=False` or check that " + "`pytorch_lora_weights.safetensors` exists in the model folder.", + UserWarning, + ) + + self.sd_pipeline.scheduler = DDIMScheduler.from_config(self.sd_pipeline.scheduler.config) + self.sd_pipeline.safety_checker = None + + # memory optimization + self.sd_pipeline.vae.requires_grad_(False) + self.sd_pipeline.text_encoder.requires_grad_(False) + self.sd_pipeline.unet.requires_grad_(not self.use_lora) + + def __call__(self, *args, **kwargs) -> DDPOPipelineOutput: + return pipeline_step(self.sd_pipeline, *args, **kwargs) + + def rgb_with_grad(self, *args, **kwargs) -> DDPOPipelineOutput: + return pipeline_step_with_grad(self.sd_pipeline, *args, **kwargs) + + def scheduler_step(self, *args, **kwargs) -> DDPOSchedulerOutput: + return scheduler_step(self.sd_pipeline.scheduler, *args, **kwargs) + + @property + def unet(self): + return self.sd_pipeline.unet + + @property + def vae(self): + return self.sd_pipeline.vae + + @property + def tokenizer(self): + return self.sd_pipeline.tokenizer + + @property + def scheduler(self): + return self.sd_pipeline.scheduler + + @property + def text_encoder(self): + return self.sd_pipeline.text_encoder + + @property + def autocast(self): + return contextlib.nullcontext if self.use_lora else None + + def save_pretrained(self, output_dir): + if self.use_lora: + state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(self.sd_pipeline.unet)) + self.sd_pipeline.save_lora_weights(save_directory=output_dir, unet_lora_layers=state_dict) + self.sd_pipeline.save_pretrained(output_dir) + + def set_progress_bar_config(self, *args, **kwargs): + self.sd_pipeline.set_progress_bar_config(*args, **kwargs) + + def get_trainable_layers(self): + if self.use_lora: + lora_config = LoraConfig( + r=4, + lora_alpha=4, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + self.sd_pipeline.unet.add_adapter(lora_config) + + # To avoid accelerate unscaling problems in FP16. + for param in self.sd_pipeline.unet.parameters(): + # only upcast trainable parameters (LoRA) into fp32 + if param.requires_grad: + param.data = param.to(torch.float32) + return self.sd_pipeline.unet + else: + return self.sd_pipeline.unet + + def save_checkpoint(self, models, weights, output_dir): + if len(models) != 1: + raise ValueError("Given how the trainable params were set, this should be of length 1") + if self.use_lora and hasattr(models[0], "peft_config") and getattr(models[0], "peft_config", None) is not None: + state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(models[0])) + self.sd_pipeline.save_lora_weights(save_directory=output_dir, unet_lora_layers=state_dict) + elif not self.use_lora and isinstance(models[0], UNet2DConditionModel): + models[0].save_pretrained(os.path.join(output_dir, "unet")) + else: + raise ValueError(f"Unknown model type {type(models[0])}") + + def load_checkpoint(self, models, input_dir): + if len(models) != 1: + raise ValueError("Given how the trainable params were set, this should be of length 1") + if self.use_lora: + lora_state_dict, network_alphas = self.sd_pipeline.lora_state_dict( + input_dir, weight_name="pytorch_lora_weights.safetensors" + ) + self.sd_pipeline.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=models[0]) + + elif not self.use_lora and isinstance(models[0], UNet2DConditionModel): + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + models[0].register_to_config(**load_model.config) + models[0].load_state_dict(load_model.state_dict()) + del load_model + else: + raise ValueError(f"Unknown model type {type(models[0])}") diff --git a/trl/models/modeling_value_head.py b/trl/models/modeling_value_head.py new file mode 100644 index 0000000000000000000000000000000000000000..e5cd18d2b97b0c33cf9e68a1c0d6edaf658b3a64 --- /dev/null +++ b/trl/models/modeling_value_head.py @@ -0,0 +1,437 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, is_torch_npu_available, is_torch_xpu_available + +from .modeling_base import PreTrainedModelWrapper + + +class ValueHead(nn.Module): + r""" + The ValueHead class implements a head for GPT2 that returns a scalar for each output token. + """ + + def __init__(self, config, **kwargs): + super().__init__() + if not hasattr(config, "summary_dropout_prob"): + summary_dropout_prob = kwargs.pop("summary_dropout_prob", 0.1) + else: + summary_dropout_prob = config.summary_dropout_prob + + self.dropout = nn.Dropout(summary_dropout_prob) if summary_dropout_prob else nn.Identity() + + # some models such as OPT have a projection layer before the word embeddings - e.g. OPT-350m + if hasattr(config, "hidden_size"): + hidden_size = config.hidden_size + if hasattr(config, "word_embed_proj_dim"): + hidden_size = config.word_embed_proj_dim + elif hasattr(config, "is_encoder_decoder"): + if config.is_encoder_decoder and hasattr(config, "decoder"): + if hasattr(config.decoder, "hidden_size"): + hidden_size = config.decoder.hidden_size + + self.summary = nn.Linear(hidden_size, 1) + + self.flatten = nn.Flatten() + + def forward(self, hidden_states): + output = self.dropout(hidden_states) + + # For now force upcast in fp32 if needed. Let's keep the + # output in fp32 for numerical stability. + if output.dtype != self.summary.weight.dtype: + output = output.to(self.summary.weight.dtype) + + output = self.summary(output) + return output + + +class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper): + r""" + An autoregressive model with a value head in addition to the language model head. + This class inherits from `~trl.PreTrainedModelWrapper` and wraps a + `transformers.PreTrainedModel` class. The wrapper class supports classic functions + such as `from_pretrained`, `push_to_hub` and `generate`. To call a method of the wrapped + model, simply manipulate the `pretrained_model` attribute of this class. + + Class attributes: + - **transformers_parent_class** (`transformers.PreTrainedModel`) -- The parent class of the wrapped model. This + should be set to `transformers.AutoModelForCausalLM` for this class. + - **supported_args** (`tuple`) -- A tuple of strings that are used to identify the arguments that are supported + by the `ValueHead` class. Currently, the supported args are: + - **summary_dropout_prob** (`float`, `optional`, defaults to `None`) -- The dropout probability for the + `ValueHead` class. + - **v_head_initializer_range** (`float`, `optional`, defaults to `0.2`) -- The initializer range for the + `ValueHead` if a specific initialization strategy is selected. + - **v_head_init_strategy** (`str`, `optional`, defaults to `None`) -- The initialization strategy for the + `ValueHead`. Currently, the supported strategies are: + - **`None`** -- Initializes the weights of the `ValueHead` with a random distribution. This is the default + strategy. + - **"normal"** -- Initializes the weights of the `ValueHead` with a normal distribution. + """ + + transformers_parent_class = AutoModelForCausalLM + supported_args = ( + "summary_dropout_prob", + "v_head_initializer_range", + "v_head_init_strategy", + ) + + def __init__(self, pretrained_model, **kwargs): + r""" + Initializes the model. + + Args: + pretrained_model (`transformers.PreTrainedModel`): + The model to wrap. It should be a causal language model such as GPT2. + or any model mapped inside the `AutoModelForCausalLM` class. + kwargs (`dict`, `optional`): + Additional keyword arguments, that are passed to the `ValueHead` class. + """ + super().__init__(pretrained_model, **kwargs) + v_head_kwargs, _, _ = self._split_kwargs(kwargs) + self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs) + self._init_weights(**v_head_kwargs) + + def _init_weights(self, **kwargs): + r""" + Initializes the weights of the value head. The default initialization strategy is random. + Users can pass a different initialization strategy by passing the `v_head_init_strategy` argument + when calling `.from_pretrained`. Supported strategies are: + - `normal`: initializes the weights with a normal distribution. + + Args: + **kwargs (`dict`, `optional`): + Additional keyword arguments, that are passed to the `ValueHead` class. These arguments + can contain the `v_head_init_strategy` argument as well as the `v_head_initializer_range` + argument. + """ + initializer_range = kwargs.pop("v_head_initializer_range", 0.2) + # random init by default + init_strategy = kwargs.pop("v_head_init_strategy", None) + if init_strategy is None: + # do nothing + pass + elif init_strategy == "normal": + self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range) + self.v_head.summary.bias.data.zero_() + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + return_past_key_values=False, + **kwargs, + ): + r""" + Applies a forward pass to the wrapped model and returns the logits of the value head. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + past_key_values (`tuple(tuple(torch.FloatTensor))`, `optional`): + Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model + (see `past_key_values` input) to speed up sequential decoding. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + return_past_key_values (bool): A flag indicating if the computed hidden-states should be returned. + kwargs (`dict`, `optional`): + Additional keyword arguments, that are passed to the wrapped model. + """ + kwargs["output_hidden_states"] = True # this had already been set in the LORA / PEFT examples + kwargs["past_key_values"] = past_key_values + + if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING": + kwargs.pop("past_key_values") + + base_model_output = self.pretrained_model( + input_ids=input_ids, + attention_mask=attention_mask, + **kwargs, + ) + + last_hidden_state = base_model_output.hidden_states[-1] + lm_logits = base_model_output.logits + loss = base_model_output.loss + + if last_hidden_state.device != self.v_head.summary.weight.device: + last_hidden_state = last_hidden_state.to(self.v_head.summary.weight.device) + + value = self.v_head(last_hidden_state).squeeze(-1) + + # force upcast in fp32 if logits are in half-precision + if lm_logits.dtype != torch.float32: + lm_logits = lm_logits.float() + + if return_past_key_values: + return (lm_logits, loss, value, base_model_output.past_key_values) + else: + return (lm_logits, loss, value) + + def generate(self, *args, **kwargs): + r""" + A simple wrapper around the `generate` method of the wrapped model. + Please refer to the [`generate`](https://huggingface.co/docs/transformers/internal/generation_utils) + method of the wrapped model for more information about the supported arguments. + + Args: + *args (`list`, *optional*): + Positional arguments passed to the `generate` method of the wrapped model. + **kwargs (`dict`, *optional*): + Keyword arguments passed to the `generate` method of the wrapped model. + """ + return self.pretrained_model.generate(*args, **kwargs) + + def state_dict(self, *args, **kwargs): + r""" + Returns the state dictionary of the model. We add the state dictionary of the value head + to the state dictionary of the wrapped model by prepending the key with `v_head.`. + """ + if not self.is_peft_model: + pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs) + else: + # if it is a peft model, only save the v_head + pretrained_model_state_dict = {} + + v_head_state_dict = self.v_head.state_dict(*args, **kwargs) + for k, v in v_head_state_dict.items(): + pretrained_model_state_dict[f"v_head.{k}"] = v + return pretrained_model_state_dict + + def push_to_hub(self, *args, **kwargs): + self.pretrained_model.v_head = self.v_head + + return self.pretrained_model.push_to_hub(*args, **kwargs) + + def post_init(self, state_dict): + r""" + We add the state dictionary of the value head to the state dictionary of the wrapped model + by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the + keys of the value head state dictionary. + """ + for k in list(state_dict.keys()): + if "v_head." in k: + state_dict[k.replace("v_head.", "")] = state_dict.pop(k) + self.v_head.load_state_dict(state_dict, strict=False) + del state_dict + + if hasattr(self.pretrained_model, "hf_device_map"): + if ( + "cpu" in self.pretrained_model.hf_device_map.values() + or "disk" in self.pretrained_model.hf_device_map.values() + ): + raise ValueError( + "The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models." + ) + + first_device = list(set(self.pretrained_model.hf_device_map.values()))[0] + if isinstance(first_device, int): + if is_torch_npu_available(): + first_device = f"npu:{first_device}" + elif is_torch_xpu_available(): + first_device = f"xpu:{first_device}" + else: + first_device = f"cuda:{first_device}" + self.v_head = self.v_head.to(first_device) + + def set_device_hook(module, input, outputs): + new_output = () + for output in outputs: + if isinstance(output, torch.Tensor): + new_output += (output.to(first_device),) + else: + new_output += (output,) + return new_output + + self.register_forward_hook(set_device_hook) + + self.is_sequential_parallel = True + + +class AutoModelForSeq2SeqLMWithValueHead(PreTrainedModelWrapper): + r""" + A seq2seq model with a value head in addition to the language model head. + This class inherits from `~trl.PreTrainedModelWrapper` and wraps a + `transformers.PreTrainedModel` class. The wrapper class supports classic functions + such as `from_pretrained` and `push_to_hub` and also provides some additional + functionalities such as `generate`. + + Args: + pretrained_model (`transformers.PreTrainedModel`): + The model to wrap. It should be a causal language model such as GPT2. + or any model mapped inside the `AutoModelForSeq2SeqLM` class. + kwargs: + Additional keyword arguments passed along to the `ValueHead` class. + """ + + transformers_parent_class = AutoModelForSeq2SeqLM + lm_head_namings = ["lm_head", "embed_out", "output_projection"] + supported_args = ( + "summary_dropout_prob", + "v_head_initializer_range", + "v_head_init_strategy", + ) + + def __init__(self, pretrained_model, **kwargs): + super().__init__(pretrained_model, **kwargs) + v_head_kwargs, _, _ = self._split_kwargs(kwargs) + self.is_encoder_decoder = True + + if not self._has_lm_head(): + raise ValueError("The model does not have a language model head, please use a model that has one.") + + self.v_head = ValueHead(self.pretrained_model.config, **v_head_kwargs) + + self._init_weights(**v_head_kwargs) + + def _has_lm_head(self): + # check module names of all modules inside `pretrained_model` to find the language model head + for name, _module in self.pretrained_model.named_modules(): + if any(attribute in name for attribute in self.lm_head_namings): + return True + return False + + def post_init(self, state_dict): + r""" + We add the state dictionary of the value head to the state dictionary of the wrapped model + by prepending the key with `v_head.`. This function removes the `v_head.` prefix from the + keys of the value head state dictionary. + """ + for k in list(state_dict.keys()): + if "v_head." in k: + state_dict[k.replace("v_head.", "")] = state_dict.pop(k) + self.v_head.load_state_dict(state_dict, strict=False) + del state_dict + + if hasattr(self.pretrained_model, "hf_device_map"): + if ( + "cpu" in self.pretrained_model.hf_device_map.values() + or "disk" in self.pretrained_model.hf_device_map.values() + ): + raise ValueError( + "The model is offloaded on CPU or disk - CPU & disk offloading is not supported for ValueHead models." + ) + + # get the lm_head device + for name, module in self.pretrained_model.named_modules(): + if any(attribute in name for attribute in self.lm_head_namings): + lm_head_device = module.weight.device + break + + # put v_head on the same device as the lm_head to avoid issues + self.v_head = self.v_head.to(lm_head_device) + + def set_device_hook(module, input, outputs): + r""" + A hook that sets the device of the output of the model to the device of the first + parameter of the model. + + Args: + module (`nn.Module`): + The module to which the hook is attached. + input (`tuple`): + The input to the module. + outputs (`tuple`): + The output of the module. + """ + new_output = () + for output in outputs: + if isinstance(output, torch.Tensor): + new_output += (output.to(lm_head_device),) + else: + new_output += (output,) + return new_output + + self.register_forward_hook(set_device_hook) + self.is_sequential_parallel = True + + def state_dict(self, *args, **kwargs): + r""" + Returns the state dictionary of the model. We add the state dictionary of the value head + to the state dictionary of the wrapped model by prepending the key with `v_head.`. + """ + if not self.is_peft_model: + pretrained_model_state_dict = self.pretrained_model.state_dict(*args, **kwargs) + else: + # if it is a peft model, only save the v_head + pretrained_model_state_dict = {} + + v_head_state_dict = self.v_head.state_dict(*args, **kwargs) + for k, v in v_head_state_dict.items(): + pretrained_model_state_dict[f"v_head.{k}"] = v + return pretrained_model_state_dict + + def push_to_hub(self, *args, **kwargs): + self.pretrained_model.v_head = self.v_head + + return self.pretrained_model.push_to_hub(*args, **kwargs) + + def _init_weights(self, **kwargs): + r""" + We initialize the weights of the value head. + """ + initializer_range = kwargs.pop("v_head_initializer_range", 0.2) + # random init by default + init_strategy = kwargs.pop("v_head_init_strategy", None) + if init_strategy is None: + # do nothing + pass + elif init_strategy == "normal": + self.v_head.summary.weight.data.normal_(mean=0.0, std=initializer_range) + self.v_head.summary.bias.data.zero_() + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + return_past_key_values=False, + **kwargs, + ): + kwargs["past_key_values"] = past_key_values + if self.is_peft_model and self.pretrained_model.active_peft_config.peft_type == "PREFIX_TUNING": + kwargs.pop("past_key_values") + + base_model_output = self.pretrained_model( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, # We force the model to output hidden states + **kwargs, + ) + + last_hidden_state = base_model_output.decoder_hidden_states[-1] + lm_logits = base_model_output.logits + loss = base_model_output.loss + + value = self.v_head(last_hidden_state).squeeze(-1) + + # force upcast in fp32 if logits are in half-precision + if lm_logits.dtype != torch.float32: + lm_logits = lm_logits.float() + + if return_past_key_values: + return (lm_logits, loss, value, base_model_output.past_key_values) + else: + return (lm_logits, loss, value) + + def generate(self, *args, **kwargs): + r""" + We call `generate` on the wrapped model. + """ + return self.pretrained_model.generate(*args, **kwargs) diff --git a/trl/models/sd_utils.py b/trl/models/sd_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..05501d54aacec375dcfef7b276fb2a24542267e9 --- /dev/null +++ b/trl/models/sd_utils.py @@ -0,0 +1,150 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +State dict utilities: utility methods for converting state dicts easily +File copied from diffusers to avoid import issues and make TRL compatible +with most of diffusers versions. +""" + +import enum + + +class StateDictType(enum.Enum): + """ + The mode to use when converting state dicts. + """ + + DIFFUSERS_OLD = "diffusers_old" + PEFT = "peft" + + +PEFT_TO_DIFFUSERS = { + ".q_proj.lora_B": ".q_proj.lora_linear_layer.up", + ".q_proj.lora_A": ".q_proj.lora_linear_layer.down", + ".k_proj.lora_B": ".k_proj.lora_linear_layer.up", + ".k_proj.lora_A": ".k_proj.lora_linear_layer.down", + ".v_proj.lora_B": ".v_proj.lora_linear_layer.up", + ".v_proj.lora_A": ".v_proj.lora_linear_layer.down", + ".out_proj.lora_B": ".out_proj.lora_linear_layer.up", + ".out_proj.lora_A": ".out_proj.lora_linear_layer.down", + "to_k.lora_A": "to_k.lora.down", + "to_k.lora_B": "to_k.lora.up", + "to_q.lora_A": "to_q.lora.down", + "to_q.lora_B": "to_q.lora.up", + "to_v.lora_A": "to_v.lora.down", + "to_v.lora_B": "to_v.lora.up", + "to_out.0.lora_A": "to_out.0.lora.down", + "to_out.0.lora_B": "to_out.0.lora.up", +} + +DIFFUSERS_OLD_TO_DIFFUSERS = { + ".to_q_lora.up": ".q_proj.lora_linear_layer.up", + ".to_q_lora.down": ".q_proj.lora_linear_layer.down", + ".to_k_lora.up": ".k_proj.lora_linear_layer.up", + ".to_k_lora.down": ".k_proj.lora_linear_layer.down", + ".to_v_lora.up": ".v_proj.lora_linear_layer.up", + ".to_v_lora.down": ".v_proj.lora_linear_layer.down", + ".to_out_lora.up": ".out_proj.lora_linear_layer.up", + ".to_out_lora.down": ".out_proj.lora_linear_layer.down", +} + +DIFFUSERS_STATE_DICT_MAPPINGS = { + StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_DIFFUSERS, + StateDictType.PEFT: PEFT_TO_DIFFUSERS, +} + +KEYS_TO_ALWAYS_REPLACE = { + ".processor.": ".", +} + + +def convert_state_dict(state_dict, mapping): + r""" + Simply iterates over the state dict and replaces the patterns in `mapping` with the corresponding values. + + Args: + state_dict (`dict[str, torch.Tensor]`): + The state dict to convert. + mapping (`dict[str, str]`): + The mapping to use for conversion, the mapping should be a dictionary with the following structure: + - key: the pattern to replace + - value: the pattern to replace with + + Returns: + converted_state_dict (`dict`) + The converted state dict. + """ + converted_state_dict = {} + for k, v in state_dict.items(): + # First, filter out the keys that we always want to replace + for pattern in KEYS_TO_ALWAYS_REPLACE.keys(): + if pattern in k: + new_pattern = KEYS_TO_ALWAYS_REPLACE[pattern] + k = k.replace(pattern, new_pattern) + + for pattern in mapping.keys(): + if pattern in k: + new_pattern = mapping[pattern] + k = k.replace(pattern, new_pattern) + break + converted_state_dict[k] = v + return converted_state_dict + + +def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs): + r""" + Converts a state dict to new diffusers format. The state dict can be from previous diffusers format + (`OLD_DIFFUSERS`), or PEFT format (`PEFT`) or new diffusers format (`DIFFUSERS`). In the last case the method will + return the state dict as is. + + The method only supports the conversion from diffusers old, PEFT to diffusers new for now. + + Args: + state_dict (`dict[str, torch.Tensor]`): + The state dict to convert. + original_type (`StateDictType`, *optional*): + The original type of the state dict, if not provided, the method will try to infer it automatically. + kwargs (`dict`, *args*): + Additional arguments to pass to the method. + + - **adapter_name**: For example, in case of PEFT, some keys will be pre-pended + with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in + `get_peft_model_state_dict` method: + https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92 + but we add it here in case we don't want to rely on that method. + """ + peft_adapter_name = kwargs.pop("adapter_name", None) + if peft_adapter_name is not None: + peft_adapter_name = "." + peft_adapter_name + else: + peft_adapter_name = "" + + if original_type is None: + # Old diffusers to PEFT + if any("to_out_lora" in k for k in state_dict.keys()): + original_type = StateDictType.DIFFUSERS_OLD + elif any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()): + original_type = StateDictType.PEFT + elif any("lora_linear_layer" in k for k in state_dict.keys()): + # nothing to do + return state_dict + else: + raise ValueError("Could not automatically infer state dict type") + + if original_type not in DIFFUSERS_STATE_DICT_MAPPINGS.keys(): + raise ValueError(f"Original type {original_type} is not supported") + + mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type] + return convert_state_dict(state_dict, mapping) diff --git a/trl/models/utils.py b/trl/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..251269be1de20167b5474f43b1eaaec3ad3b7d78 --- /dev/null +++ b/trl/models/utils.py @@ -0,0 +1,341 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +from contextlib import contextmanager +from copy import deepcopy +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal, Optional, Union + +import torch.nn as nn +from packaging import version +from transformers import PreTrainedModel, PreTrainedTokenizer + +from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead + + +SUPPORTED_ARCHITECTURES = ( + AutoModelForCausalLMWithValueHead, + AutoModelForSeq2SeqLMWithValueHead, +) + +if TYPE_CHECKING: + from accelerate import Accelerator + from deepspeed.runtime.engine import DeepSpeedEngine + from torch.nn import Module + from torch.nn.parallel.distributed import DistributedDataParallel + + +# TODO: Add Abstract Base Class if more formats are added +@dataclass +class ChatMlSpecialTokens: + """Dataclass for special tokens used in ChatML, including system, user, assistant, bos, eos, and pad tokens.""" + + bos_token: str = "<|im_start|>" + eos_token: str = "<|im_end|>" + pad_token: str = "<|im_end|>" + + @property + def system(self): + return f"{self.bos_token}system" + + @property + def user(self): + return f"{self.bos_token}user" + + @property + def assistant(self): + return f"{self.bos_token}assistant" + + @property + def chat_template(self): + return ( + "{% for message in messages %}" + f"{{{{'{self.bos_token}' + message['role'] + '\n' + message['content'] + '{self.eos_token}' + '\n'}}}}" + "{% endfor %}" + "{% if add_generation_prompt %}" + f"{{{{ '{self.assistant}\n' }}}}" + "{% endif %}" + ) + + +FORMAT_MAPPING = {"chatml": ChatMlSpecialTokens} + + +def setup_chat_format( + model: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + format: Optional[Literal["chatml"]] = "chatml", + resize_to_multiple_of: Optional[int] = None, +) -> tuple[PreTrainedModel, PreTrainedTokenizer]: + """ + Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the embedding layer of the model based on the new special tokens. + + If the model already has a chat template, this will throw an error. If you want to overwrite it, please set `tokenizer.chat_template` to `None`. + + Args: + model (`~transformers.PreTrainedModel`): The model to be modified. + tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified. + format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml". + resize_to_multiple_of (`int` or `None`): Number to resize the embedding layer to. Defaults to None. + + Returns: + model (`~transformers.PreTrainedModel`): The modified model. + tokenizer (`~transformers.PreTrainedTokenizer`): The modified tokenizer. + """ + # check if model already had a chat template + if tokenizer.chat_template is not None: + raise ValueError( + "Chat template is already added to the tokenizer. If you want to overwrite it, please set it to None" + ) + + # check if format available and retrieve + if format not in FORMAT_MAPPING: + raise ValueError(f"Format {format} not available. Please use one of {FORMAT_MAPPING.keys()}") + + chat_format = FORMAT_MAPPING[format]() + + # set special tokens and them + tokenizer.eos_token = chat_format.eos_token + tokenizer.pad_token = chat_format.pad_token + tokenizer.bos_token = chat_format.bos_token + tokenizer.add_special_tokens({"additional_special_tokens": [chat_format.bos_token, chat_format.eos_token]}) + # set chat format for tokenizer + tokenizer.chat_template = chat_format.chat_template + + # resize embedding layer to a multiple of 64, https://x.com/karpathy/status/1621578354024677377 + model.resize_token_embeddings( + len(tokenizer), pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None + ) + # Update the model config to use the new eos & bos tokens + if getattr(model, "config", None) is not None: + model.config.pad_token_id = tokenizer.pad_token_id + model.config.bos_token_id = tokenizer.bos_token_id + model.config.eos_token_id = tokenizer.eos_token_id + # Update the generation config to use the new eos & bos token + if getattr(model, "generation_config", None) is not None: + model.generation_config.bos_token_id = tokenizer.bos_token_id + model.generation_config.eos_token_id = tokenizer.eos_token_id + model.generation_config.pad_token_id = tokenizer.pad_token_id + + return model, tokenizer + + +def remove_hooks(model: "DeepSpeedEngine") -> None: + """Removes the optimizer hooks from a DeepSpeed ZeRO-3 model.""" + if not hasattr(model, "optimizer"): # before the first training step, the model has no optimizer + return + if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"): + optimizer_offload = model.optimizer.parameter_offload + elif model.optimizer is not None: + optimizer_offload = model.optimizer + else: + raise RuntimeError("The model optimizer is None, which is not yet supported.") + + for param in iter_params(optimizer_offload.module, recurse=True): + param.ds_active_sub_modules.clear() + + for hook in optimizer_offload.forward_hooks: + hook.remove() + for hook in optimizer_offload.backward_hooks: + hook.remove() + + optimizer_offload.forward_hooks = [] + optimizer_offload.backward_hooks = [] + + +def get_all_parameters(sub_module, recurse=False): + return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters()) + + +def iter_params(module, recurse=False): + return [param for _, param in get_all_parameters(module, recurse)] + + +def add_hooks(model: "DeepSpeedEngine") -> None: + """Adds the optimizer hooks from a DeepSpeed ZeRO-3 model.""" + import deepspeed + + if not hasattr(model, "optimizer"): # before the first training step, the model has no optimizer + return + if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"): + optimizer_offload = model.optimizer.parameter_offload + elif model.optimizer is not None: + optimizer_offload = model.optimizer + else: + raise RuntimeError("The model optimizer is None, which is not yet supported.") + if version.parse(deepspeed.__version__) >= version.parse("0.16.4"): + # Account for renaming in https://github.com/deepspeedai/DeepSpeed/pull/6847 + optimizer_offload._register_deepspeed_module(optimizer_offload.module) + else: + optimizer_offload._register_hooks_recursively(optimizer_offload.module) + + +@contextmanager +def unwrap_model_for_generation( + model: Union["DistributedDataParallel", "DeepSpeedEngine"], + accelerator: "Accelerator", + gather_deepspeed3_params: bool = True, +): + """ + Context manager to unwrap distributed or accelerated models for generation tasks. + + Args: + model (`Union[DistributedDataParallel, DeepSpeedEngine]`): + Model to be unwrapped. + accelerator (`~accelerate.Accelerator`): + Accelerator instance managing the model. + gather_deepspeed3_params (`bool`, *optional*, defaults to `True`): + Whether to gather weights for DeepSpeed ZeRO Stage 3 models. If `False`, skips parameter gathering, which + can be more memory-efficient but may lead to slower generation times. + + Yields: + Unwrapped model. + + Example: + ```python + with unwrap_model_for_generation(model, accelerator) as unwrapped_model: + generated_outputs = unwrapped_model.generate(input_ids) + ``` + """ + unwrapped_model = accelerator.unwrap_model(model) + if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3: + if not gather_deepspeed3_params: + yield accelerator.unwrap_model(model) + else: + import deepspeed + + with deepspeed.zero.GatheredParameters(model.parameters()): + remove_hooks(model) + yield accelerator.unwrap_model(model) + add_hooks(model) + else: + yield unwrapped_model + + +def prepare_deepspeed(model: "Module", accelerator: "Accelerator"): + """Prepares the model for DeepSpeed inference or evaluation by initializing it with the appropriate configuration. + + Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 + """ + import deepspeed # local import (instead of top-level) to avoid DS init interfering with other backends (like vllm): https://github.com/deepspeedai/DeepSpeed/issues/7252 + + deepspeed_plugin = accelerator.state.deepspeed_plugin + config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config) + stage = config_kwargs["zero_optimization"]["stage"] + + if model is not None: + hidden_size = ( + max(model.config.hidden_sizes) + if getattr(model.config, "hidden_sizes", None) + else getattr(model.config, "hidden_size", None) + ) + if hidden_size is not None and stage == 3: + # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache + # @ step 0: expected module 1, but got module 0` + # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081 + config_kwargs.update( + { + "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, + "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, + "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, + } + ) + + # If ZeRO-3 is used, we shard both the active and reference model. + # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO + # disabled (stage 0) + if stage != 3: + config_kwargs["zero_optimization"]["stage"] = 0 + model, *_ = deepspeed.initialize(model=model, config=config_kwargs) + model.eval() + return model + + +def prepare_fsdp(model, accelerator): + # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1421 + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP + + # Check if the model is already a FSDP model due to `Manual Wrapping` and if so, + # don't wrap it again + if not isinstance(model, FSDP): + accelerator.state.fsdp_plugin.set_auto_wrap_policy(model) + fsdp_plugin = accelerator.state.fsdp_plugin + kwargs = { + "sharding_strategy": fsdp_plugin.sharding_strategy or fsdp_plugin.reshard_after_forward, + "cpu_offload": fsdp_plugin.cpu_offload, + "auto_wrap_policy": fsdp_plugin.auto_wrap_policy, + "mixed_precision": fsdp_plugin.mixed_precision_policy, + "sync_module_states": fsdp_plugin.sync_module_states, + "backward_prefetch": fsdp_plugin.backward_prefetch, + "forward_prefetch": fsdp_plugin.forward_prefetch, + "use_orig_params": fsdp_plugin.use_orig_params, + "param_init_fn": fsdp_plugin.param_init_fn, + "ignored_modules": fsdp_plugin.ignored_modules, + "limit_all_gathers": fsdp_plugin.limit_all_gathers, + "device_id": accelerator.device, + } + model = FSDP(model, **kwargs) + model.eval() + return model + + +class _ForwardRedirection: + """Implements the `forward-redirection`. + + Taken from Pytorch-lightning: https://github.com/Lightning-AI/pytorch-lightning/blob/02311d03fb982560246eead7c08104481fac9579/src/lightning/pytorch/strategies/strategy.py#L602 + + A method call to a wrapped module gets rerouted through the wrapper's `forward` method instead. + + """ + + def __call__( + self, wrapper_module: nn.Module, original_module: nn.Module, method: callable, *args: Any, **kwargs: Any + ): + """Reroutes a method call through the `wrapper_module`'s `forward` method. + + Args: + wrapper_module: The module that has `original_module` wrapped. + original_module: The module that was wrapped inside `wrapper_module`. + method_name: The name of the method that should be called on the `original_module` after inputs get + redirected through the `wrapper_module`'s `forward` method. + *args: The positional arguments to the method `method_name`. They will get passed to a patched + `forward` method instead. + **kwargs: The keyword arguments to the method `method_name`. They will get passed to a patched + `forward` method instead. + + """ + original_forward = original_module.forward + + def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any: + # Unpatch ourselves immediately before calling the method `method_name` + # because itself may want to call the real `forward` + original_module.forward = original_forward # type: ignore[method-assign] + # Call the actual method e.g. `.training_step(...)` + out = method(*_args, **_kwargs) + self.on_after_inner_forward(wrapper_module, original_module) + return out + + # Patch the original_module's forward so we can redirect the arguments back to the real method + original_module.forward = wrapped_forward # type: ignore[method-assign] + + wrapper_output = wrapper_module(*args, **kwargs) + self.on_after_outer_forward(wrapper_module, original_module) + return wrapper_output + + def on_after_inner_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None: + pass + + def on_after_outer_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None: + pass diff --git a/trl/rewards/__init__.py b/trl/rewards/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..136d70af9cb5d95a855ce5bcedf5b5f9b526cb69 --- /dev/null +++ b/trl/rewards/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import sys +from typing import TYPE_CHECKING + +from ..import_utils import _LazyModule + + +_import_structure = { + "format_rewards": ["think_format_reward"], +} + + +if TYPE_CHECKING: + from .format_rewards import think_format_reward + + +else: + sys.modules[__name__] = _LazyModule(__name__, __file__, _import_structure, module_spec=__spec__) diff --git a/trl/rewards/format_rewards.py b/trl/rewards/format_rewards.py new file mode 100644 index 0000000000000000000000000000000000000000..b27e342c47a01b94b19e48a21832c8f48d8da24f --- /dev/null +++ b/trl/rewards/format_rewards.py @@ -0,0 +1,49 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + + +def think_format_reward(completions: list[list[dict[str, str]]], **kwargs) -> list[float]: + r""" + Reward function that checks if the reasoning process is enclosed within `""` and `""` tags. The + function returns a reward of 1.0 if the format is correct, otherwise 0.0. + + Args: + completions (`list[list[dict[str, str]]]`): + List of completions to be evaluated. Each completion must be a list of one message, i.e. a dictionary + containing the key `"content"` with the value being the text of the completion. + **kwargs: + Additional keyword arguments. This function does not use them, but they are required in the function + signature to ensure compatibility with trainers like [`GRPOTrainer`]. + + Returns: + `list[float]`: + A list of rewards, where each reward is 1.0 if the completion matches the expected format, otherwise 0.0. + + Example: + ```python + >>> from trl.rewards import think_format_reward + >>> completions = [ + ... [{"content": "\nThis is my reasoning.\n\nThis is my answer."}], + ... [{"content": "\nThis is my reasoning.\nThis is my answer."}], + ... ] + >>> think_format_reward(completions) + [1.0, 0.0] + ``` + """ + pattern = r"^(?!.*)(.*?).*$" + completion_contents = [completion[0]["content"] for completion in completions] + matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents] + return [1.0 if match else 0.0 for match in matches] diff --git a/trl/scripts/__init__.py b/trl/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..272a717db0e9a01d8c9ff23c21d44d6e804bc99c --- /dev/null +++ b/trl/scripts/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ..import_utils import _LazyModule + + +_import_structure = { + "utils": ["init_zero_verbose", "ScriptArguments", "TrlParser"], +} + +if TYPE_CHECKING: + from .utils import ScriptArguments, TrlParser, init_zero_verbose +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/trl/scripts/dpo.py b/trl/scripts/dpo.py new file mode 100644 index 0000000000000000000000000000000000000000..5f43ac88ba57e4755148f7690f9df75af0b92b88 --- /dev/null +++ b/trl/scripts/dpo.py @@ -0,0 +1,159 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +# Full training +python trl/scripts/dpo.py \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --dataset_streaming \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --learning_rate 5.0e-7 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --gradient_checkpointing \ + --logging_steps 25 \ + --eval_strategy steps \ + --eval_steps 50 \ + --output_dir Qwen2-0.5B-DPO \ + --no_remove_unused_columns + --report_to wandb + +# LoRA: +python trl/scripts/dpo.py \ + --dataset_name trl-lib/ultrafeedback_binarized \ + --dataset_streaming \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --learning_rate 5.0e-6 \ + --num_train_epochs 1 \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --gradient_checkpointing \ + --logging_steps 25 \ + --eval_strategy steps \ + --eval_steps 50 \ + --output_dir Qwen2-0.5B-DPO \ + --no_remove_unused_columns \ + --use_peft \ + --lora_r 32 \ + --lora_alpha 16 + --report_to wandb +""" + +import argparse + +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from trl import ( + DPOConfig, + DPOTrainer, + ModelConfig, + ScriptArguments, + TrlParser, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) +from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE + + +def main(script_args, training_args, model_args): + ################ + # Model & Tokenizer + ################### + torch_dtype = ( + model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) + ) + quantization_config = get_quantization_config(model_args) + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + torch_dtype=torch_dtype, + use_cache=False if training_args.gradient_checkpointing else True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs + ) + peft_config = get_peft_config(model_args) + if peft_config is None: + ref_model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs + ) + else: + ref_model = None + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + if tokenizer.chat_template is None: + tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE + if script_args.ignore_bias_buffers: + # torch distributed hack + model._ddp_params_and_buffers_to_ignore = [ + name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool + ] + + ################ + # Dataset + ################ + dataset = load_dataset( + script_args.dataset_name, + name=script_args.dataset_config, + streaming=script_args.dataset_streaming, + ) + + ########## + # Training + ################ + trainer = DPOTrainer( + model, + ref_model, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=tokenizer, + peft_config=peft_config, + ) + + trainer.train() + + if training_args.eval_strategy != "no": + metrics = trainer.evaluate() + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + + +def make_parser(subparsers: argparse._SubParsersAction = None): + dataclass_types = (ScriptArguments, DPOConfig, ModelConfig) + if subparsers is not None: + parser = subparsers.add_parser("dpo", help="Run the DPO training script", dataclass_types=dataclass_types) + else: + parser = TrlParser(dataclass_types) + return parser + + +if __name__ == "__main__": + parser = make_parser() + script_args, training_args, model_args = parser.parse_args_and_config() + main(script_args, training_args, model_args) diff --git a/trl/scripts/env.py b/trl/scripts/env.py new file mode 100644 index 0000000000000000000000000000000000000000..7eb81b4544e352065a7c45a1a41c5587c809a0a4 --- /dev/null +++ b/trl/scripts/env.py @@ -0,0 +1,84 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import platform +from importlib.metadata import version + +import torch +from accelerate.commands.config import default_config_file, load_config_from_file +from transformers import is_bitsandbytes_available +from transformers.utils import is_openai_available, is_peft_available + +from .. import __version__ +from ..import_utils import ( + is_deepspeed_available, + is_diffusers_available, + is_liger_kernel_available, + is_llm_blender_available, + is_vllm_available, +) +from .utils import get_git_commit_hash + + +def print_env(): + devices = None + if torch.cuda.is_available(): + devices = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())] + elif torch.backends.mps.is_available(): + devices = ["MPS"] + elif torch.xpu.is_available(): + devices = [torch.xpu.get_device_name(i) for i in range(torch.xpu.device_count())] + + accelerate_config = accelerate_config_str = "not found" + + # Get the default from the config file. + if os.path.isfile(default_config_file): + accelerate_config = load_config_from_file(default_config_file).to_dict() + + accelerate_config_str = ( + "\n" + "\n".join([f" - {prop}: {val}" for prop, val in accelerate_config.items()]) + if isinstance(accelerate_config, dict) + else accelerate_config + ) + + commit_hash = get_git_commit_hash("trl") + + info = { + "Platform": platform.platform(), + "Python version": platform.python_version(), + "TRL version": f"{__version__}+{commit_hash[:7]}" if commit_hash else __version__, + "PyTorch version": version("torch"), + "accelerator(s)": ", ".join(devices) if devices is not None else "cpu", + "Transformers version": version("transformers"), + "Accelerate version": version("accelerate"), + "Accelerate config": accelerate_config_str, + "Datasets version": version("datasets"), + "HF Hub version": version("huggingface_hub"), + "bitsandbytes version": version("bitsandbytes") if is_bitsandbytes_available() else "not installed", + "DeepSpeed version": version("deepspeed") if is_deepspeed_available() else "not installed", + "Diffusers version": version("diffusers") if is_diffusers_available() else "not installed", + "Liger-Kernel version": version("liger_kernel") if is_liger_kernel_available() else "not installed", + "LLM-Blender version": version("llm_blender") if is_llm_blender_available() else "not installed", + "OpenAI version": version("openai") if is_openai_available() else "not installed", + "PEFT version": version("peft") if is_peft_available() else "not installed", + "vLLM version": version("vllm") if is_vllm_available() else "not installed", + } + + info_str = "\n".join([f"- {prop}: {val}" for prop, val in info.items()]) + print(f"\nCopy-paste the following information when reporting an issue:\n\n{info_str}\n") # noqa + + +if __name__ == "__main__": + print_env() diff --git a/trl/scripts/grpo.py b/trl/scripts/grpo.py new file mode 100644 index 0000000000000000000000000000000000000000..180671e6b4f26d9ee85adc03457852ff364539e1 --- /dev/null +++ b/trl/scripts/grpo.py @@ -0,0 +1,132 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import importlib +import os +import sys +from dataclasses import dataclass, field +from typing import Optional + +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer + +from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config +from trl.rewards import think_format_reward + + +reward_funcs_registry = { + "think_format_reward": think_format_reward, +} + + +@dataclass +class GRPOScriptArguments(ScriptArguments): + """ + Script arguments for the GRPO training script. + + Args: + reward_model_name_or_path (`str` or `None`, *optional*, defaults to `None`): + Reward model id of a pretrained model hosted inside a model repo on huggingface.co or local path to a + directory containing model weights saved using [`~transformers.PreTrainedModel.save_pretrained`]. + reward_funcs (`list[str]` or `None`, *optional*, defaults to `None`): + Reward functions to use. It can be either one of `"think_format_reward"`; or a dotted import path " + (e.g., `'my_lib.rewards.custom_reward'`). + """ + + reward_model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": "Reward model id of a pretrained model hosted inside a model repo on huggingface.co or " + "local path to a directory containing model weights saved using `PreTrainedModel.save_pretrained`." + }, + ) + reward_funcs: Optional[list[str]] = field( + default=None, + metadata={ + "help": "Reward functions to use. It can be either one of 'think_format_reward'; or a dotted " + "import path. (e.g., 'my_lib.rewards.custom_reward')." + }, + ) + + +def main(script_args, training_args, model_args): + # Load a pretrained model + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + + # Get the reward models and functions + reward_funcs = [] + if script_args.reward_model_name_or_path: + reward_model = AutoModelForSequenceClassification.from_pretrained( + script_args.reward_model_name_or_path, trust_remote_code=model_args.trust_remote_code, num_labels=1 + ) + reward_funcs.append(reward_model) + + if script_args.reward_funcs: + for func_name in script_args.reward_funcs: + if func_name in reward_funcs_registry: + reward_funcs.append(reward_funcs_registry[func_name]) + elif "." in func_name: + module_path, func_name = func_name.rsplit(".", 1) + sys.path.insert(0, os.getcwd()) + module = importlib.import_module(module_path) + reward_func = getattr(module, func_name) + reward_funcs.append(reward_func) + else: + raise ValueError( + f"Could not load reward function '{func_name}'. Expected one of " + f"{list(reward_funcs_registry.keys())} or a valid import path." + ) + + # Load the dataset + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + + # Initialize the GRPO trainer + trainer = GRPOTrainer( + model=model, + reward_funcs=reward_model, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=tokenizer, + peft_config=get_peft_config(model_args), + ) + + # Train and push the model to the Hub + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + + +def make_parser(subparsers: argparse._SubParsersAction = None): + dataclass_types = (GRPOScriptArguments, GRPOConfig, ModelConfig) + if subparsers is not None: + parser = subparsers.add_parser("grpo", help="Run the GRPO training script", dataclass_types=dataclass_types) + else: + parser = TrlParser(dataclass_types) + return parser + + +if __name__ == "__main__": + parser = make_parser() + script_args, training_args, model_args = parser.parse_args_and_config() + main(script_args, training_args, model_args) diff --git a/trl/scripts/kto.py b/trl/scripts/kto.py new file mode 100644 index 0000000000000000000000000000000000000000..3ef7e4c7f73127401c951c5f6ee988b31762097e --- /dev/null +++ b/trl/scripts/kto.py @@ -0,0 +1,128 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Run the KTO training script with the commands below. In general, the optimal configuration for KTO will be similar to that of DPO. + +# Full training: +python trl/scripts/kto.py \ + --dataset_name trl-lib/kto-mix-14k \ + --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ + --per_device_train_batch_size 16 \ + --num_train_epochs 1 \ + --learning_rate 5e-7 \ + --lr_scheduler_type=cosine \ + --gradient_accumulation_steps 1 \ + --logging_steps 10 \ + --eval_steps 500 \ + --output_dir=kto-aligned-model \ + --warmup_ratio 0.1 \ + --report_to wandb \ + --bf16 \ + --logging_first_step + +# QLoRA: +python trl/scripts/kto.py \ + --dataset_name trl-lib/kto-mix-14k \ + --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ + --per_device_train_batch_size 8 \ + --num_train_epochs 1 \ + --learning_rate 5e-7 \ + --lr_scheduler_type=cosine \ + --gradient_accumulation_steps 1 \ + --logging_steps 10 \ + --eval_steps 500 \ + --output_dir=kto-aligned-model-lora \ + --warmup_ratio 0.1 \ + --report_to wandb \ + --bf16 \ + --logging_first_step \ + --use_peft \ + --load_in_4bit \ + --lora_target_modules=all-linear \ + --lora_r=16 \ + --lora_alpha=16 +""" + +import argparse + +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from trl import ( + KTOConfig, + KTOTrainer, + ModelConfig, + ScriptArguments, + TrlParser, + get_peft_config, + setup_chat_format, +) + + +def main(script_args, training_args, model_args): + # Load a pretrained model + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + ref_model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + # If we are aligning a base model, we use ChatML as the default template + if tokenizer.chat_template is None: + model, tokenizer = setup_chat_format(model, tokenizer) + + # Load the dataset + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + + # Initialize the KTO trainer + trainer = KTOTrainer( + model, + ref_model, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=tokenizer, + peft_config=get_peft_config(model_args), + ) + + # Train and push the model to the Hub + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + + +def make_parser(subparsers: argparse._SubParsersAction = None): + dataclass_types = (ScriptArguments, KTOConfig, ModelConfig) + if subparsers is not None: + parser = subparsers.add_parser("kto", help="Run the KTO training script", dataclass_types=dataclass_types) + else: + parser = TrlParser(dataclass_types) + return parser + + +if __name__ == "__main__": + parser = make_parser() + script_args, training_args, model_args = parser.parse_args_and_config() + main(script_args, training_args, model_args) diff --git a/trl/scripts/sft.py b/trl/scripts/sft.py new file mode 100644 index 0000000000000000000000000000000000000000..923829f77528c9c395ed1d7e0ff70d429c534ddc --- /dev/null +++ b/trl/scripts/sft.py @@ -0,0 +1,149 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +# Full training +python trl/scripts/sft.py \ + --model_name_or_path Qwen/Qwen2-0.5B \ + --dataset_name trl-lib/Capybara \ + --learning_rate 2.0e-5 \ + --num_train_epochs 1 \ + --packing \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --gradient_checkpointing \ + --eos_token '<|im_end|>' \ + --logging_steps 25 \ + --eval_strategy steps \ + --eval_steps 100 \ + --output_dir Qwen2-0.5B-SFT \ + --push_to_hub + +# LoRA +python trl/scripts/sft.py \ + --model_name_or_path Qwen/Qwen2-0.5B \ + --dataset_name trl-lib/Capybara \ + --learning_rate 2.0e-4 \ + --num_train_epochs 1 \ + --packing \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 8 \ + --gradient_checkpointing \ + --eos_token '<|im_end|>' \ + --logging_steps 25 \ + --eval_strategy steps \ + --eval_steps 100 \ + --use_peft \ + --lora_r 32 \ + --lora_alpha 16 \ + --output_dir Qwen2-0.5B-SFT \ + --push_to_hub +""" + +import argparse + +from datasets import load_dataset +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES + +from trl import ( + ModelConfig, + ScriptArguments, + SFTConfig, + SFTTrainer, + TrlParser, + get_kbit_device_map, + get_peft_config, + get_quantization_config, + setup_chat_format, +) + + +def main(script_args, training_args, model_args): + ################ + # Model init kwargs & Tokenizer + ################ + quantization_config = get_quantization_config(model_args) + model_kwargs = dict( + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + attn_implementation=model_args.attn_implementation, + torch_dtype=model_args.torch_dtype, + use_cache=False if training_args.gradient_checkpointing else True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + + # Create model + config = AutoConfig.from_pretrained(model_args.model_name_or_path) + valid_image_text_architectures = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values() + + if config.architectures and any(arch in valid_image_text_architectures for arch in config.architectures): + from transformers import AutoModelForImageTextToText + + model_kwargs.pop("use_cache", None) # Image models do not support cache + model = AutoModelForImageTextToText.from_pretrained(model_args.model_name_or_path, **model_kwargs) + else: + model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs) + + # Create tokenizer + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True + ) + + # Set default chat template if needed + if tokenizer.chat_template is None: + model, tokenizer = setup_chat_format(model, tokenizer, format="chatml") + + ################ + # Dataset + ################ + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + + ################ + # Training + ################ + trainer = SFTTrainer( + model=model, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=tokenizer, + peft_config=get_peft_config(model_args), + ) + + trainer.train() + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) + + +def make_parser(subparsers: argparse._SubParsersAction = None): + dataclass_types = (ScriptArguments, SFTConfig, ModelConfig) + if subparsers is not None: + parser = subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types) + else: + parser = TrlParser(dataclass_types) + return parser + + +if __name__ == "__main__": + parser = make_parser() + # When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments. + # To ensure that their parsing does not interfere with the script arguments, parse the arguments with + # `return_remaining_strings=True`, then ignore the remaining strings. + script_args, training_args, model_args, _ = parser.parse_args_and_config(return_remaining_strings=True) + main(script_args, training_args, model_args) diff --git a/trl/scripts/utils.py b/trl/scripts/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b4e335f20a861c8c03d4d177a56ea6b2c7ea2333 --- /dev/null +++ b/trl/scripts/utils.py @@ -0,0 +1,282 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import importlib +import inspect +import logging +import os +import subprocess +import sys +from collections.abc import Iterable +from dataclasses import dataclass, field +from typing import Optional, Union + +import yaml +from transformers import HfArgumentParser +from transformers.hf_argparser import DataClass, DataClassType +from transformers.utils import is_rich_available + + +logger = logging.getLogger(__name__) + + +@dataclass +class ScriptArguments: + """ + Arguments common to all scripts. + + Args: + dataset_name (`str`): + Dataset name. + dataset_config (`str` or `None`, *optional*, defaults to `None`): + Dataset configuration name. Corresponds to the `name` argument of the [`~datasets.load_dataset`] function. + dataset_train_split (`str`, *optional*, defaults to `"train"`): + Dataset split to use for training. + dataset_test_split (`str`, *optional*, defaults to `"test"`): + Dataset split to use for evaluation. + dataset_streaming (`bool`, *optional*, defaults to `False`): + Whether to stream the dataset. If True, the dataset will be loaded in streaming mode. + gradient_checkpointing_use_reentrant (`bool`, *optional*, defaults to `False`): + Whether to apply `use_reentrant` for gradient checkpointing. + ignore_bias_buffers (`bool`, *optional*, defaults to `False`): + Debug argument for distributed training. Fix for DDP issues with LM bias/mask buffers - invalid scalar + type, inplace operation. See https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992. + """ + + dataset_name: Optional[str] = field(default=None, metadata={"help": "Dataset name."}) + dataset_config: Optional[str] = field( + default=None, + metadata={ + "help": "Dataset configuration name. Corresponds to the `name` argument of the `datasets.load_dataset` " + "function." + }, + ) + dataset_train_split: str = field(default="train", metadata={"help": "Dataset split to use for training."}) + dataset_test_split: str = field(default="test", metadata={"help": "Dataset split to use for evaluation."}) + dataset_streaming: bool = field( + default=False, + metadata={"help": "Whether to stream the dataset. If True, the dataset will be loaded in streaming mode."}, + ) + gradient_checkpointing_use_reentrant: bool = field( + default=False, + metadata={"help": "Whether to apply `use_reentrant` for gradient checkpointing."}, + ) + ignore_bias_buffers: bool = field( + default=False, + metadata={ + "help": "Debug argument for distributed training. Fix for DDP issues with LM bias/mask buffers - invalid " + "scalar type, inplace operation. See " + "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992." + }, + ) + + +def init_zero_verbose(): + """ + Perform zero verbose init - use this method on top of the CLI modules to make + logging and warning output cleaner. Uses Rich if available, falls back otherwise. + """ + import logging + import warnings + + FORMAT = "%(message)s" + + if is_rich_available(): + from rich.logging import RichHandler + + handler = RichHandler() + else: + handler = logging.StreamHandler() + + logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[handler], level=logging.ERROR) + + # Custom warning handler to redirect warnings to the logging system + def warning_handler(message, category, filename, lineno, file=None, line=None): + logging.warning(f"{filename}:{lineno}: {category.__name__}: {message}") + + # Add the custom warning handler - we need to do that before importing anything to make sure the loggers work well + warnings.showwarning = warning_handler + + +class TrlParser(HfArgumentParser): + """ + A subclass of [`transformers.HfArgumentParser`] designed for parsing command-line arguments with dataclass-backed + configurations, while also supporting configuration file loading and environment variable management. + + Args: + dataclass_types (`Union[DataClassType, Iterable[DataClassType]]` or `None`, *optional*, defaults to `None`): + Dataclass types to use for argument parsing. + **kwargs: + Additional keyword arguments passed to the [`transformers.HfArgumentParser`] constructor. + + Examples: + + ```yaml + # config.yaml + env: + VAR1: value1 + arg1: 23 + ``` + + ```python + # main.py + import os + from dataclasses import dataclass + from trl import TrlParser + + @dataclass + class MyArguments: + arg1: int + arg2: str = "alpha" + + parser = TrlParser(dataclass_types=[MyArguments]) + training_args = parser.parse_args_and_config() + + print(training_args, os.environ.get("VAR1")) + ``` + + ```bash + $ python main.py --config config.yaml + (MyArguments(arg1=23, arg2='alpha'),) value1 + + $ python main.py --arg1 5 --arg2 beta + (MyArguments(arg1=5, arg2='beta'),) None + ``` + """ + + def __init__( + self, + dataclass_types: Optional[Union[DataClassType, Iterable[DataClassType]]] = None, + **kwargs, + ): + # Make sure dataclass_types is an iterable + if dataclass_types is None: + dataclass_types = [] + elif not isinstance(dataclass_types, Iterable): + dataclass_types = [dataclass_types] + + # Check that none of the dataclasses have the "config" field + for dataclass_type in dataclass_types: + if "config" in dataclass_type.__dataclass_fields__: + raise ValueError( + f"Dataclass {dataclass_type.__name__} has a field named 'config'. This field is reserved for the " + f"config file path and should not be used in the dataclass." + ) + + super().__init__(dataclass_types=dataclass_types, **kwargs) + + def parse_args_and_config( + self, + args: Optional[Iterable[str]] = None, + return_remaining_strings: bool = False, + fail_with_unknown_args: bool = True, + ) -> tuple[DataClass, ...]: + """ + Parse command-line args and config file into instances of the specified dataclass types. + + This method wraps [`transformers.HfArgumentParser.parse_args_into_dataclasses`] and also parses the config file + specified with the `--config` flag. The config file (in YAML format) provides argument values that replace the + default values in the dataclasses. Command line arguments can override values set by the config file. The + method also sets any environment variables specified in the `env` field of the config file. + """ + args = list(args) if args is not None else sys.argv[1:] + if "--config" in args: + # Get the config file path from + config_index = args.index("--config") + args.pop(config_index) # remove the --config flag + config_path = args.pop(config_index) # get the path to the config file + with open(config_path) as yaml_file: + config = yaml.safe_load(yaml_file) + + # Set the environment variables specified in the config file + if "env" in config: + env_vars = config.pop("env", {}) + if not isinstance(env_vars, dict): + raise ValueError("`env` field should be a dict in the YAML file.") + for key, value in env_vars.items(): + os.environ[key] = str(value) + + # Set the defaults from the config values + config_remaining_strings = self.set_defaults_with_config(**config) + else: + config_remaining_strings = [] + + # Parse the arguments from the command line + output = self.parse_args_into_dataclasses(args=args, return_remaining_strings=return_remaining_strings) + + # Merge remaining strings from the config file with the remaining strings from the command line + if return_remaining_strings: + args_remaining_strings = output[-1] + return output[:-1] + (config_remaining_strings + args_remaining_strings,) + elif fail_with_unknown_args and config_remaining_strings: + raise ValueError( + f"Unknown arguments from config file: {config_remaining_strings}. Please remove them, add them to the " + "dataclass, or set `fail_with_unknown_args=False`." + ) + else: + return output + + def set_defaults_with_config(self, **kwargs) -> list[str]: + """ + Overrides the parser's default values with those provided via keyword arguments, including for subparsers. + + Any argument with an updated default will also be marked as not required + if it was previously required. + + Returns a list of strings that were not consumed by the parser. + """ + + def apply_defaults(parser, kw): + used_keys = set() + for action in parser._actions: + # Handle subparsers recursively + if isinstance(action, argparse._SubParsersAction): + for subparser in action.choices.values(): + used_keys.update(apply_defaults(subparser, kw)) + elif action.dest in kw: + action.default = kw[action.dest] + action.required = False + used_keys.add(action.dest) + return used_keys + + used_keys = apply_defaults(self, kwargs) + # Remaining args not consumed by the parser + remaining = [ + item for key, value in kwargs.items() if key not in used_keys for item in (f"--{key}", str(value)) + ] + return remaining + + +def get_git_commit_hash(package_name): + try: + # Import the package to locate its path + package = importlib.import_module(package_name) + # Get the path to the package using inspect + package_path = os.path.dirname(inspect.getfile(package)) + + # Navigate up to the Git repository root if the package is inside a subdirectory + git_repo_path = os.path.abspath(os.path.join(package_path, "..")) + git_dir = os.path.join(git_repo_path, ".git") + + if os.path.isdir(git_dir): + # Run the git command to get the current commit hash + commit_hash = ( + subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=git_repo_path).strip().decode("utf-8") + ) + return commit_hash + else: + return None + except Exception as e: + return f"Error: {str(e)}" diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py new file mode 100644 index 0000000000000000000000000000000000000000..d5a0dc99fe4ae847ab8c3b3687e6da4cc02be88c --- /dev/null +++ b/trl/scripts/vllm_serve.py @@ -0,0 +1,584 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import os +from collections.abc import Sequence +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from itertools import chain +from multiprocessing import Pipe, Process +from multiprocessing.connection import Connection +from typing import Optional + +import torch + +from trl import TrlParser +from trl.import_utils import ( + is_fastapi_available, + is_pydantic_available, + is_uvicorn_available, + is_vllm_ascend_available, + is_vllm_available, +) + + +if is_fastapi_available(): + from fastapi import FastAPI + + +if is_pydantic_available(): + from pydantic import BaseModel + + +if is_uvicorn_available(): + import uvicorn + + +if is_vllm_available(): + from vllm import LLM, SamplingParams + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + from vllm.distributed.parallel_state import get_world_group + from vllm.distributed.utils import StatelessProcessGroup + from vllm.sampling_params import GuidedDecodingParams + from vllm.utils import get_open_port + + if is_vllm_ascend_available(): + from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator as PyNcclCommunicator + + +logger = logging.getLogger(__name__) + +# We use CUDA with multiprocessing, so we must use the 'spawn' start method. Otherwise, we will get the following +# error: RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use +# the 'spawn' start method +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + + +class WeightSyncWorkerExtension: + """ + A vLLM worker extension that enables weight synchronization between a client and multiple server workers. + + This worker uses a `StatelessProcessGroup` to establish communication and a `PyNcclCommunicator` to handle + efficient GPU-based communication using NCCL. The primary purpose of this class is to receive updated model weights + from a client process and distribute them to all worker processes participating in model inference. + """ + + # The following attributes are initialized when `init_communicator` method is called. + pynccl_comm = None # Communicator for weight updates + client_rank = None # Source rank for broadcasting updated weights + + def init_communicator(self, host: str, port: int, world_size: int) -> None: + """ + Initializes the weight update communicator using a stateless process group. + + This method creates a `StatelessProcessGroup` that allows external training processes to + communicate with vLLM workers without interfering with the global torch distributed group. + + Args: + host (`str`): + Hostname or IP address of the master node. + port (`int`): + Port number to be used for communication. + world_size (`int`): + Total number of participating processes in the update group. + """ + if self.pynccl_comm is not None: + raise RuntimeError("Weight update group already initialized. Call close_communicator first.") + + # Get the rank of the current worker in the global world group. + rank = get_world_group().rank + + # Create a stateless process group to manage communication between training processes and vLLM workers. + pg = StatelessProcessGroup.create(host=host, port=port, rank=rank, world_size=world_size) + + # Initialize the NCCL-based communicator for weight synchronization. + self.pynccl_comm = PyNcclCommunicator(pg, device=self.device) + + # The client process that sends updated weights has the highest rank (world_size - 1). + self.client_rank = world_size - 1 + + def update_named_param(self, name: str, dtype: torch.dtype, shape: Sequence[int]) -> None: + """ + Receives updated weights from the client process and updates the named parameter in the model. + + Args: + name (`str`): + Name of the weight tensor being updated. + dtype (`torch.dtype`): + Data type of the weight tensor (e.g., `torch.float32`). + shape (`Sequence[int]`): + Shape of the weight tensor. + """ + if self.pynccl_comm is None: + raise RuntimeError("Communicator not initialized. Call `init_communicator` first.") + + # Allocate memory for the incoming weight tensor on the correct device. + weight = torch.empty(shape, dtype=dtype, device=self.device) + + # Use NCCL to broadcast the updated weights from the client (src) to all workers. + self.pynccl_comm.broadcast(weight, src=self.client_rank) + self.pynccl_comm.group.barrier() + + # Load the received weights into the model. + self.model_runner.model.load_weights(weights=[(name, weight)]) + + def close_communicator(self) -> None: + """ + Closes the communicator when weight synchronization is no longer needed. + + This method deletes the NCCL communicator to release associated resources. + """ + + if self.pynccl_comm is not None: + del self.pynccl_comm + self.pynccl_comm = None # Ensure attribute is reset to None + self.client_rank = None # Ensure attribute is reset to None + + +@dataclass +class ScriptArguments: + r""" + Arguments for the script. + + Args: + model (`str`): + Model name or path to load the model from. + revision (`str` or `None`, *optional*, defaults to `None`): + Revision to use for the model. If not specified, the default branch will be used. + tensor_parallel_size (`int`, *optional*, defaults to `1`): + Number of tensor parallel workers to use. + data_parallel_size (`int`, *optional*, defaults to `1`): + Number of data parallel workers to use. + host (`str`, *optional*, defaults to `"0.0.0.0"`): + Host address to run the server on. + port (`int`, *optional*, defaults to `8000`): + Port to run the server on. + gpu_memory_utilization (`float`, *optional*, defaults to `0.9`): + Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the + device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus + improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors + during initialization. + dtype (`str`, *optional*, defaults to `"auto"`): + Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined + based on the model configuration. Find the supported values in the vLLM documentation. + max_model_len (`int` or `None`, *optional*, defaults to `None`): + If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced + `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model + context size, which might be much larger than the KV cache, leading to inefficiencies. + enable_prefix_caching (`bool` or `None`, *optional*, defaults to `None`): + Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware support + this feature. + enforce_eager (`bool` or `None`, *optional*, defaults to `None`): + Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always execute the + model in eager mode. If `False` (default behavior), we will use CUDA graph and eager execution in hybrid. + kv_cache_dtype (`str`, *optional*, defaults to `"auto"`): + Data type to use for KV cache. If set to `"auto"`, the dtype will default to the model data type. + log_level (`str`, *optional*, defaults to `"info"`): + Log level for uvicorn. Possible choices: `"critical"`, `"error"`, `"warning"`, `"info"`, `"debug"`, + `"trace"`. + """ + + model: str = field( + metadata={"help": "Model name or path to load the model from."}, + ) + revision: Optional[str] = field( + default=None, + metadata={"help": "Revision to use for the model. If not specified, the default branch will be used."}, + ) + tensor_parallel_size: int = field( + default=1, + metadata={"help": "Number of tensor parallel workers to use."}, + ) + data_parallel_size: int = field( + default=1, + metadata={"help": "Number of data parallel workers to use."}, + ) + host: str = field( + default="0.0.0.0", + metadata={"help": "Host address to run the server on."}, + ) + port: int = field( + default=8000, + metadata={"help": "Port to run the server on."}, + ) + gpu_memory_utilization: float = field( + default=0.9, + metadata={ + "help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV " + "cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache " + "size and thus improve the model's throughput. However, if the value is too high, it may cause " + "out-of-memory (OOM) errors during initialization." + }, + ) + dtype: str = field( + default="auto", + metadata={ + "help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically " + "determined based on the model configuration. Find the supported values in the vLLM documentation." + }, + ) + max_model_len: Optional[int] = field( + default=None, + metadata={ + "help": "If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced " + "`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model " + "context size, which might be much larger than the KV cache, leading to inefficiencies." + }, + ) + enable_prefix_caching: Optional[bool] = field( + default=None, + metadata={ + "help": "Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the " + "hardware support this feature." + }, + ) + enforce_eager: Optional[bool] = field( + default=None, + metadata={ + "help": "Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always " + "execute the model in eager mode. If `False` (default behavior), we will use CUDA graph and eager " + "execution in hybrid." + }, + ) + kv_cache_dtype: str = field( + default="auto", + metadata={ + "help": "Data type to use for KV cache. If set to 'auto', the dtype will default to the model data type." + }, + ) + log_level: str = field( + default="info", + metadata={ + "help": "Log level for uvicorn. Possible choices: 'critical', 'error', 'warning', 'info', 'debug', " + "'trace'." + }, + ) + + +def llm_worker( + script_args: ScriptArguments, data_parallel_rank: int, master_port: int, connection: Connection +) -> None: + # Set required environment variables for DP to work with vLLM + os.environ["VLLM_DP_RANK"] = str(data_parallel_rank) + os.environ["VLLM_DP_RANK_LOCAL"] = str(data_parallel_rank) + os.environ["VLLM_DP_SIZE"] = str(script_args.data_parallel_size) + os.environ["VLLM_DP_MASTER_PORT"] = str(master_port) + + llm = LLM( + model=script_args.model, + revision=script_args.revision, + tensor_parallel_size=script_args.tensor_parallel_size, + gpu_memory_utilization=script_args.gpu_memory_utilization, + enforce_eager=script_args.enforce_eager, + dtype=script_args.dtype, + # Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can + # directly reuse the KV cache if it shares the same prefix with one of the existing queries. + # This is particularly useful here because we generate completions from the same prompts. + enable_prefix_caching=script_args.enable_prefix_caching, + kv_cache_dtype=script_args.kv_cache_dtype, + max_model_len=script_args.max_model_len, + worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension", + ) + + # Send ready signal to parent process + connection.send({"status": "ready"}) + + while True: + # Wait for commands from the parent process + try: + command = connection.recv() + except KeyboardInterrupt: + llm.collective_rpc(method="close_communicator") + break + + # Handle commands + if command["type"] in ["call", "fire_and_forget"]: + method_name = command["method"] + args, kwargs = command.get("args", ()), command.get("kwargs", {}) + method = getattr(llm, method_name) + result = method(*args, **kwargs) + if command["type"] == "call": + connection.send(result) + elif command["type"] == "shutdown": + break + + +def chunk_list(lst: list, n: int) -> list[list]: + """ + Split list `lst` into `n` evenly distributed sublists. + + Example: + >>> chunk_list([1, 2, 3, 4, 5, 6], 2) + [[1, 2, 3], [4, 5, 6]] + >>> chunk_list([1, 2, 3, 4, 5, 6], 4) + [[1, 2], [3, 4], [5], [6]] + >>> chunk_list([1, 2, 3, 4, 5, 6], 8) + [[1], [2], [3], [4], [5], [6], [], []] + """ + k, r = divmod(len(lst), n) + return [lst[i * k + min(i, r) : (i + 1) * k + min(i + 1, r)] for i in range(n)] + + +def main(script_args: ScriptArguments): + if not is_fastapi_available(): + raise ImportError( + "FastAPI is required to run the vLLM serve script. Please install it using `pip install fastapi`." + ) + + if not is_pydantic_available(): + raise ImportError( + "Pydantic is required to run the vLLM serve script. Please install it using `pip install pydantic`." + ) + + if not is_uvicorn_available(): + raise ImportError( + "Uvicorn is required to run the vLLM serve script. Please install it using `pip install uvicorn`." + ) + + if not is_vllm_available(): + raise ImportError("vLLM is required to run the vLLM serve script. Please install it using `pip install vllm`.") + + # Spawn dp workers, and setup pipes for communication + master_port = get_open_port() + connections = [] + processes = [] + for data_parallel_rank in range(script_args.data_parallel_size): + parent_connection, child_connection = Pipe() + process = Process(target=llm_worker, args=(script_args, data_parallel_rank, master_port, child_connection)) + process.start() + connections.append(parent_connection) + processes.append(process) + + @asynccontextmanager + async def lifespan(app: FastAPI): + # Wait for all workers to send "ready" + ready_connections = set() + while len(ready_connections) < script_args.data_parallel_size: + for connection in connections: + msg = connection.recv() + if isinstance(msg, dict) and msg.get("status") == "ready": + ready_connections.add(connection) + + yield + + # Wait for processes to terminate + for process in processes: + process.join(timeout=10) # Wait for 10 seconds for the process to terminate + if process.is_alive(): + logger.warning(f"Process {process} is still alive after 10 seconds, attempting to terminate...") + process.terminate() + process.join() # ensure process termination after calling terminate() + + app = FastAPI(lifespan=lifespan) + + # Define the endpoints for the model server + @app.get("/health/") + async def health(): + """ + Health check endpoint to verify that the server is running. + """ + return {"status": "ok"} + + @app.get("/get_world_size/") + async def get_world_size(): + """ + Retrieves the world size of the LLM engine, which is `tensor_parallel_size * data_parallel_size`. + + Returns: + `dict`: + A dictionary containing the world size. + + Example response: + ```json + {"world_size": 8} + ``` + """ + return {"world_size": script_args.tensor_parallel_size * script_args.data_parallel_size} + + class GenerateRequest(BaseModel): + prompts: list[str] + n: int = 1 + repetition_penalty: float = 1.0 + temperature: float = 1.0 + top_p: float = 1.0 + top_k: int = -1 + min_p: float = 0.0 + max_tokens: int = 16 + guided_decoding_regex: Optional[str] = None + + class GenerateResponse(BaseModel): + completion_ids: list[list[int]] + + @app.post("/generate/", response_model=GenerateResponse) + async def generate(request: GenerateRequest): + """ + Generates completions for the provided prompts. + + Args: + request (`GenerateRequest`): + - `prompts` (list of `str`): A list of prompts (text strings) for the model to generate completions. + + Returns: + `GenerateResponse`: + - `completion_ids` (list of list of `int`): A list of lists of token IDs for each generated completion. + + Example request: + ```json + {"prompts": ["Hello world", "What is AI?"]} + ``` + + Example response: + ```json + {"completion_ids": [[101, 102, 103], [201, 202, 203]]} + ``` + """ + + # Guided decoding, if enabled + if request.guided_decoding_regex is not None: + guided_decoding = GuidedDecodingParams(backend="outlines", regex=request.guided_decoding_regex) + else: + guided_decoding = None + + # Sampling parameters + sampling_params = SamplingParams( + n=request.n, + repetition_penalty=request.repetition_penalty, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + min_p=request.min_p, + max_tokens=request.max_tokens, + guided_decoding=guided_decoding, + ) + # Evenly distribute prompts across DP ranks + chunked_prompts = chunk_list(request.prompts, script_args.data_parallel_size) + + # Send the prompts to each worker + for connection, prompts in zip(connections, chunked_prompts): + # When the number of prompts is less than data_parallel_size, some workers will receive empty prompts. + # However, vLLM requires that we always send at least one prompt. So we send a placeholder prompt to comply + # with vLLM's requirement, and we later ignore the result. + if not prompts: + prompts = [""] + kwargs = {"prompts": prompts, "sampling_params": sampling_params} + connection.send({"type": "call", "method": "generate", "kwargs": kwargs}) + + # Receive results + all_outputs = [connection.recv() for connection in connections] + + # Handle empty prompts (see above) + all_outputs = [output for output, prompts in zip(all_outputs, chunked_prompts) if prompts] + + # Flatten and combine all results + all_outputs = list(chain.from_iterable(all_outputs)) # from list of list to single list + completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs] + return {"completion_ids": completion_ids} + + class InitCommunicatorRequest(BaseModel): + host: str + port: int + world_size: int + + @app.post("/init_communicator/") + async def init_communicator(request: InitCommunicatorRequest): + """ + Initializes the communicator for synchronizing model weights between a client and multiple server + workers. + + Args: + request (`InitCommunicatorRequest`): + - `host` (`str`): Hostname or IP address of the master node. + - `port` (`int`): Port number to be used for communication. + - `world_size` (`int`): Total number of participating processes in the group. + """ + world_size = script_args.tensor_parallel_size * script_args.data_parallel_size + 1 + + # The function init_communicator is called this way: init_communicator(host, port, world_size) + # So with collective_rpc we need to call it this way: + # llm.collective_rpc(method="init_communicator", args=(host, port, world_size)) + kwargs = {"method": "init_communicator", "args": (request.host, request.port, world_size)} + for connection in connections: + connection.send({"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}) + + return {"message": "Request received, initializing communicator"} + + class UpdateWeightsRequest(BaseModel): + name: str + dtype: str + shape: list[int] + + @app.post("/update_named_param/") + async def update_named_param(request: UpdateWeightsRequest): + """ + Updates the model weights with the provided tensor. + + Once this endpoint is called, the client process should broadcast the updated weights to all server workers. + + Args: + request (`UpdateWeightsRequest`): + - `name` (`str`): Name of the weight tensor being updated. + - `dtype` (`str`): Data type of the weight tensor (e.g., `"torch.float32"`). + - `shape` (list of `int`): Shape of the weight + + """ + # The function update_named_param is called this way: update_named_param("name", torch.float32, (10, 10)) + # So with collective_rpc we need to call it this way: + # llm.collective_rpc("update_named_param", args=("name", torch.float32, (10, 10))) + dtype = torch.__getattribute__(request.dtype.split(".")[-1]) + kwargs = {"method": "update_named_param", "args": (request.name, dtype, tuple(request.shape))} + for connection in connections: + connection.send({"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}) + + return {"message": "Request received, updating named parameter"} + + @app.post("/reset_prefix_cache/") + async def reset_prefix_cache(): + """ + Resets the prefix cache for the model. + """ + for connection in connections: + connection.send({"type": "call", "method": "reset_prefix_cache"}) + # Wait for and collect all results + all_outputs = [connection.recv() for connection in connections] + success = all(output for output in all_outputs) + return {"message": "Request received, resetting prefix cache status: " + str(success)} + + @app.post("/close_communicator/") + async def close_communicator(): + """ + Closes the weight update group and cleans up associated resources. + """ + kwargs = {"method": "close_communicator"} + for connection in connections: + connection.send({"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}) + return {"message": "Request received, closing communicator"} + + # Start the server + uvicorn.run(app, host=script_args.host, port=script_args.port, log_level=script_args.log_level) + + +def make_parser(subparsers: argparse._SubParsersAction = None): + if subparsers is not None: + parser = subparsers.add_parser("vllm-serve", help="Run the vLLM serve script", dataclass_types=ScriptArguments) + else: + parser = TrlParser(ScriptArguments) + return parser + + +if __name__ == "__main__": + parser = make_parser() + (script_args,) = parser.parse_args_and_config() + main(script_args) diff --git a/trl/templates/lm_model_card.md b/trl/templates/lm_model_card.md new file mode 100644 index 0000000000000000000000000000000000000000..1583c123f7569163cc276baaedeabc8c1ae69d1c --- /dev/null +++ b/trl/templates/lm_model_card.md @@ -0,0 +1,55 @@ +--- +{{ card_data }} +--- + +# Model Card for {{ model_name }} + +This model is a fine-tuned version of [{{ base_model }}](https://huggingface.co/{{ base_model }}){% if dataset_name %} on the [{{ dataset_name }}](https://huggingface.co/datasets/{{ dataset_name }}) dataset{% endif %}. +It has been trained using [TRL](https://github.com/huggingface/trl). + +## Quick start + +```python +from transformers import pipeline + +question = "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?" +generator = pipeline("text-generation", model="{{ hub_model_id }}", device="cuda") +output = generator([{"role": "user", "content": question}], max_new_tokens=128, return_full_text=False)[0] +print(output["generated_text"]) +``` + +## Training procedure + +{% if wandb_url %}[Visualize in Weights & Biases]({{ wandb_url }}){% endif %} +{% if comet_url %}[Visualize in Comet]({{ comet_url }}){% endif %} + +This model was trained with {{ trainer_name }}{% if paper_id %}, a method introduced in [{{ paper_title }}](https://huggingface.co/papers/{{ paper_id }}){% endif %}. + +### Framework versions + +- TRL: {{ trl_version }} +- Transformers: {{ transformers_version }} +- Pytorch: {{ pytorch_version }} +- Datasets: {{ datasets_version }} +- Tokenizers: {{ tokenizers_version }} + +## Citations + +{% if trainer_citation %}Cite {{ trainer_name }} as: + +```bibtex +{{ trainer_citation }} +```{% endif %} + +Cite TRL as: + +```bibtex +{% raw %}@misc{vonwerra2022trl, + title = {{TRL: Transformer Reinforcement Learning}}, + author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec}, + year = 2020, + journal = {GitHub repository}, + publisher = {GitHub}, + howpublished = {\url{https://github.com/huggingface/trl}} +}{% endraw %} +``` diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a773f911ad54496a938e012bb3dffd2b4cbe51f5 --- /dev/null +++ b/trl/trainer/__init__.py @@ -0,0 +1,165 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ..import_utils import OptionalDependencyNotAvailable, _LazyModule, is_diffusers_available + + +_import_structure = { + "alignprop_config": ["AlignPropConfig"], + "alignprop_trainer": ["AlignPropTrainer"], + "bco_config": ["BCOConfig"], + "bco_trainer": ["BCOTrainer"], + "callbacks": [ + "LogCompletionsCallback", + "MergeModelCallback", + "RichProgressCallback", + "SyncRefModelCallback", + "WinRateCallback", + ], + "cpo_config": ["CPOConfig"], + "cpo_trainer": ["CPOTrainer"], + "ddpo_config": ["DDPOConfig"], + "dpo_config": ["DPOConfig", "FDivergenceConstants", "FDivergenceType"], + "dpo_trainer": ["DPOTrainer"], + "gkd_config": ["GKDConfig"], + "gkd_trainer": ["GKDTrainer"], + "grpo_config": ["GRPOConfig"], + "grpo_trainer": ["GRPOTrainer"], + "iterative_sft_config": ["IterativeSFTConfig"], + "iterative_sft_trainer": ["IterativeSFTTrainer"], + "judges": [ + "AllTrueJudge", + "BaseBinaryJudge", + "BaseJudge", + "BasePairwiseJudge", + "BaseRankJudge", + "HfPairwiseJudge", + "OpenAIPairwiseJudge", + "PairRMJudge", + ], + "kto_config": ["KTOConfig"], + "kto_trainer": ["KTOTrainer"], + "model_config": ["ModelConfig"], + "nash_md_config": ["NashMDConfig"], + "nash_md_trainer": ["NashMDTrainer"], + "online_dpo_config": ["OnlineDPOConfig"], + "online_dpo_trainer": ["OnlineDPOTrainer"], + "orpo_config": ["ORPOConfig"], + "orpo_trainer": ["ORPOTrainer"], + "ppo_config": ["PPOConfig"], + "ppo_trainer": ["PPOTrainer"], + "prm_config": ["PRMConfig"], + "prm_trainer": ["PRMTrainer"], + "reward_config": ["RewardConfig"], + "reward_trainer": ["RewardTrainer"], + "rloo_config": ["RLOOConfig"], + "rloo_trainer": ["RLOOTrainer"], + "sft_config": ["SFTConfig"], + "sft_trainer": ["SFTTrainer"], + "utils": [ + "ConstantLengthDataset", + "DataCollatorForCompletionOnlyLM", + "RunningMoments", + "compute_accuracy", + "disable_dropout_in_model", + "empty_cache", + "peft_module_casting_to_bf16", + ], + "xpo_config": ["XPOConfig"], + "xpo_trainer": ["XPOTrainer"], +} +try: + if not is_diffusers_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["ddpo_trainer"] = ["DDPOTrainer"] + +if TYPE_CHECKING: + from .alignprop_config import AlignPropConfig + from .alignprop_trainer import AlignPropTrainer + from .bco_config import BCOConfig + from .bco_trainer import BCOTrainer + from .callbacks import ( + LogCompletionsCallback, + MergeModelCallback, + RichProgressCallback, + SyncRefModelCallback, + WinRateCallback, + ) + from .cpo_config import CPOConfig + from .cpo_trainer import CPOTrainer + from .ddpo_config import DDPOConfig + from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType + from .dpo_trainer import DPOTrainer + from .gkd_config import GKDConfig + from .gkd_trainer import GKDTrainer + from .grpo_config import GRPOConfig + from .grpo_trainer import GRPOTrainer + from .iterative_sft_trainer import IterativeSFTConfig, IterativeSFTTrainer + from .judges import ( + AllTrueJudge, + BaseBinaryJudge, + BaseJudge, + BasePairwiseJudge, + BaseRankJudge, + HfPairwiseJudge, + OpenAIPairwiseJudge, + PairRMJudge, + ) + from .kto_config import KTOConfig + from .kto_trainer import KTOTrainer + from .model_config import ModelConfig + from .nash_md_config import NashMDConfig + from .nash_md_trainer import NashMDTrainer + from .online_dpo_config import OnlineDPOConfig + from .online_dpo_trainer import OnlineDPOTrainer + from .orpo_config import ORPOConfig + from .orpo_trainer import ORPOTrainer + from .ppo_config import PPOConfig + from .ppo_trainer import PPOTrainer + from .prm_config import PRMConfig + from .prm_trainer import PRMTrainer + from .reward_config import RewardConfig + from .reward_trainer import RewardTrainer + from .rloo_config import RLOOConfig + from .rloo_trainer import RLOOTrainer + from .sft_config import SFTConfig + from .sft_trainer import SFTTrainer + from .utils import ( + ConstantLengthDataset, + DataCollatorForCompletionOnlyLM, + RunningMoments, + compute_accuracy, + disable_dropout_in_model, + empty_cache, + peft_module_casting_to_bf16, + ) + from .xpo_config import XPOConfig + from .xpo_trainer import XPOTrainer + + try: + if not is_diffusers_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .ddpo_trainer import DDPOTrainer +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/trl/trainer/alignprop_config.py b/trl/trainer/alignprop_config.py new file mode 100644 index 0000000000000000000000000000000000000000..6eef04df1d7d441a0819f00682db38c427d89f7c --- /dev/null +++ b/trl/trainer/alignprop_config.py @@ -0,0 +1,192 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +from dataclasses import dataclass, field +from typing import Any, Optional + +from transformers import is_bitsandbytes_available + +from ..core import flatten_dict + + +@dataclass +class AlignPropConfig: + r""" + Configuration class for the [`AlignPropTrainer`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`): + Name of this experiment (defaults to the file name without the extension). + run_name (`str`, *optional*, defaults to `""`): + Name of this run. + seed (`int`, *optional*, defaults to `0`): + Random seed for reproducibility. + log_with (`str` or `None`, *optional*, defaults to `None`): + Log with either `"wandb"` or `"tensorboard"`. Check + [tracking](https://huggingface.co/docs/accelerate/usage_guides/tracking) for more details. + log_image_freq (`int`, *optional*, defaults to `1`): + Frequency for logging images. + tracker_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`): + Keyword arguments for the tracker (e.g., `wandb_project`). + accelerator_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`): + Keyword arguments for the accelerator. + project_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`): + Keyword arguments for the accelerator project config (e.g., `logging_dir`). + tracker_project_name (`str`, *optional*, defaults to `"trl"`): + Name of project to use for tracking. + logdir (`str`, *optional*, defaults to `"logs"`): + Top-level logging directory for checkpoint saving. + num_epochs (`int`, *optional*, defaults to `100`): + Number of epochs to train. + save_freq (`int`, *optional*, defaults to `1`): + Number of epochs between saving model checkpoints. + num_checkpoint_limit (`int`, *optional*, defaults to `5`): + Number of checkpoints to keep before overwriting old ones. + mixed_precision (`str`, *optional*, defaults to `"fp16"`): + Mixed precision training. + allow_tf32 (`bool`, *optional*, defaults to `True`): + Allow `tf32` on Ampere GPUs. + resume_from (`str`, *optional*, defaults to `""`): + Path to resume training from a checkpoint. + sample_num_steps (`int`, *optional*, defaults to `50`): + Number of sampler inference steps. + sample_eta (`float`, *optional*, defaults to `1.0`): + Eta parameter for the DDIM sampler. + sample_guidance_scale (`float`, *optional*, defaults to `5.0`): + Classifier-free guidance weight. + train_batch_size (`int`, *optional*, defaults to `1`): + Batch size for training. + train_use_8bit_adam (`bool`, *optional*, defaults to `False`): + Whether to use the 8bit Adam optimizer from `bitsandbytes`. + train_learning_rate (`float`, *optional*, defaults to `1e-3`): + Learning rate. + train_adam_beta1 (`float`, *optional*, defaults to `0.9`): + Beta1 for Adam optimizer. + train_adam_beta2 (`float`, *optional*, defaults to `0.999`): + Beta2 for Adam optimizer. + train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`): + Weight decay for Adam optimizer. + train_adam_epsilon (`float`, *optional*, defaults to `1e-8`): + Epsilon value for Adam optimizer. + train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`): + Number of gradient accumulation steps. + train_max_grad_norm (`float`, *optional*, defaults to `1.0`): + Maximum gradient norm for gradient clipping. + negative_prompts (`str` or `None`, *optional*, defaults to `None`): + Comma-separated list of prompts to use as negative examples. + truncated_backprop_rand (`bool`, *optional*, defaults to `True`): + If `True`, randomized truncation to different diffusion timesteps is used. + truncated_backprop_timestep (`int`, *optional*, defaults to `49`): + Absolute timestep to which the gradients are backpropagated. Used only if `truncated_backprop_rand=False`. + truncated_rand_backprop_minmax (`tuple[int, int]`, *optional*, defaults to `(0, 50)`): + Range of diffusion timesteps for randomized truncated backpropagation. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether to push the final model to the Hub. + """ + + exp_name: str = field( + default=os.path.basename(sys.argv[0])[: -len(".py")], + metadata={"help": "Name of this experiment (defaults to the file name without the extension)."}, + ) + run_name: str = field(default="", metadata={"help": "Name of this run."}) + seed: int = field(default=0, metadata={"help": "Random seed for reproducibility."}) + log_with: Optional[str] = field( + default=None, + metadata={"help": "Log with either 'wandb' or 'tensorboard'.", "choices": ["wandb", "tensorboard"]}, + ) + log_image_freq: int = field(default=1, metadata={"help": "Frequency for logging images."}) + tracker_kwargs: dict[str, Any] = field( + default_factory=dict, + metadata={"help": "Keyword arguments for the tracker (e.g., `wandb_project`)."}, + ) + accelerator_kwargs: dict[str, Any] = field( + default_factory=dict, metadata={"help": "Keyword arguments for the accelerator."} + ) + project_kwargs: dict[str, Any] = field( + default_factory=dict, + metadata={"help": "Keyword arguments for the accelerator project config (e.g., `logging_dir`)."}, + ) + tracker_project_name: str = field(default="trl", metadata={"help": "Name of project to use for tracking."}) + logdir: str = field(default="logs", metadata={"help": "Top-level logging directory for checkpoint saving."}) + num_epochs: int = field(default=100, metadata={"help": "Number of epochs to train."}) + save_freq: int = field(default=1, metadata={"help": "Number of epochs between saving model checkpoints."}) + num_checkpoint_limit: int = field( + default=5, metadata={"help": "Number of checkpoints to keep before overwriting old ones."} + ) + mixed_precision: str = field( + default="fp16", + metadata={ + "help": "Mixed precision training. Possible values are 'fp16', 'bf16', 'none'.", + "choices": ["fp16", "bf16", "none"], + }, + ) + allow_tf32: bool = field(default=True, metadata={"help": "Allow `tf32` on Ampere GPUs."}) + resume_from: str = field(default="", metadata={"help": "Path to resume training from a checkpoint."}) + sample_num_steps: int = field(default=50, metadata={"help": "Number of sampler inference steps."}) + sample_eta: float = field(default=1.0, metadata={"help": "Eta parameter for the DDIM sampler."}) + sample_guidance_scale: float = field(default=5.0, metadata={"help": "Classifier-free guidance weight."}) + train_batch_size: int = field(default=1, metadata={"help": "Batch size for training."}) + train_use_8bit_adam: bool = field( + default=False, metadata={"help": "Whether to use the 8bit Adam optimizer from `bitsandbytes`."} + ) + train_learning_rate: float = field(default=1e-3, metadata={"help": "Learning rate."}) + train_adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for Adam optimizer."}) + train_adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for Adam optimizer."}) + train_adam_weight_decay: float = field(default=1e-4, metadata={"help": "Weight decay for Adam optimizer."}) + train_adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon value for Adam optimizer."}) + train_gradient_accumulation_steps: int = field( + default=1, metadata={"help": "Number of gradient accumulation steps."} + ) + train_max_grad_norm: float = field(default=1.0, metadata={"help": "Maximum gradient norm for gradient clipping."}) + negative_prompts: Optional[str] = field( + default=None, + metadata={"help": "Comma-separated list of prompts to use as negative examples."}, + ) + truncated_backprop_rand: bool = field( + default=True, + metadata={"help": "If `True`, randomized truncation to different diffusion timesteps is used."}, + ) + truncated_backprop_timestep: int = field( + default=49, + metadata={ + "help": "Absolute timestep to which the gradients are backpropagated. Used only if " + "`truncated_backprop_rand=False`." + }, + ) + truncated_rand_backprop_minmax: tuple[int, int] = field( + default=(0, 50), + metadata={ + "help": "Range of diffusion timesteps for randomized truncated backpropagation.", + }, + ) + push_to_hub: bool = field(default=False, metadata={"help": "Whether to push the final model to the Hub."}) + + def to_dict(self): + output_dict = {} + for key, value in self.__dict__.items(): + output_dict[key] = value + return flatten_dict(output_dict) + + def __post_init__(self): + if self.train_use_8bit_adam and not is_bitsandbytes_available(): + raise ImportError( + "You need to install bitsandbytes to use 8bit Adam. " + "You can install it with `pip install bitsandbytes`." + ) diff --git a/trl/trainer/alignprop_trainer.py b/trl/trainer/alignprop_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..359174421242ca04480005f3b512b3d075e81981 --- /dev/null +++ b/trl/trainer/alignprop_trainer.py @@ -0,0 +1,461 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import textwrap +from collections import defaultdict +from pathlib import Path +from typing import Any, Callable, Optional, Union +from warnings import warn + +import torch +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import PyTorchModelHubMixin +from transformers import is_wandb_available + +from ..models import DDPOStableDiffusionPipeline +from .alignprop_config import AlignPropConfig +from .utils import generate_model_card, get_comet_experiment_url + + +if is_wandb_available(): + import wandb + +logger = get_logger(__name__) + + +class AlignPropTrainer(PyTorchModelHubMixin): + """ + The AlignPropTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models. + Note, this trainer is heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/ + As of now only Stable Diffusion based pipelines are supported + + Attributes: + config (`AlignPropConfig`): + Configuration object for AlignPropTrainer. Check the documentation of `PPOConfig` for more details. + reward_function (`Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]`): + Reward function to be used + prompt_function (`Callable[[], tuple[str, Any]]`): + Function to generate prompts to guide model + sd_pipeline (`DDPOStableDiffusionPipeline`): + Stable Diffusion pipeline to be used for training. + image_samples_hook (`Optional[Callable[[Any, Any, Any], Any]]`): + Hook to be called to log images + """ + + _tag_names = ["trl", "alignprop"] + + def __init__( + self, + config: AlignPropConfig, + reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor], + prompt_function: Callable[[], tuple[str, Any]], + sd_pipeline: DDPOStableDiffusionPipeline, + image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None, + ): + if image_samples_hook is None: + warn("No image_samples_hook provided; no images will be logged") + + self.prompt_fn = prompt_function + self.reward_fn = reward_function + self.config = config + self.image_samples_callback = image_samples_hook + + accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs) + + if self.config.resume_from: + self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from)) + if "checkpoint_" not in os.path.basename(self.config.resume_from): + # get the most recent checkpoint in this directory + checkpoints = list( + filter( + lambda x: "checkpoint_" in x, + os.listdir(self.config.resume_from), + ) + ) + if len(checkpoints) == 0: + raise ValueError(f"No checkpoints found in {self.config.resume_from}") + checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints]) + self.config.resume_from = os.path.join( + self.config.resume_from, + f"checkpoint_{checkpoint_numbers[-1]}", + ) + + accelerator_project_config.iteration = checkpoint_numbers[-1] + 1 + + self.accelerator = Accelerator( + log_with=self.config.log_with, + mixed_precision=self.config.mixed_precision, + project_config=accelerator_project_config, + # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the + # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get + # the total number of optimizer steps to accumulate across. + gradient_accumulation_steps=self.config.train_gradient_accumulation_steps, + **self.config.accelerator_kwargs, + ) + + is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard" + + if self.accelerator.is_main_process: + self.accelerator.init_trackers( + self.config.tracker_project_name, + config=dict(alignprop_trainer_config=config.to_dict()) + if not is_using_tensorboard + else config.to_dict(), + init_kwargs=self.config.tracker_kwargs, + ) + + logger.info(f"\n{config}") + + set_seed(self.config.seed, device_specific=True) + + self.sd_pipeline = sd_pipeline + + self.sd_pipeline.set_progress_bar_config( + position=1, + disable=not self.accelerator.is_local_main_process, + leave=False, + desc="Timestep", + dynamic_ncols=True, + ) + + # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + if self.accelerator.mixed_precision == "fp16": + inference_dtype = torch.float16 + elif self.accelerator.mixed_precision == "bf16": + inference_dtype = torch.bfloat16 + else: + inference_dtype = torch.float32 + + self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype) + self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype) + self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype) + + trainable_layers = self.sd_pipeline.get_trainable_layers() + + self.accelerator.register_save_state_pre_hook(self._save_model_hook) + self.accelerator.register_load_state_pre_hook(self._load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if self.config.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + self.optimizer = self._setup_optimizer( + trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers + ) + + self.neg_prompt_embed = self.sd_pipeline.text_encoder( + self.sd_pipeline.tokenizer( + [""] if self.config.negative_prompts is None else self.config.negative_prompts, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.sd_pipeline.tokenizer.model_max_length, + ).input_ids.to(self.accelerator.device) + )[0] + + # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses + # more memory + self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast + + if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora: + unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer) + self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters())) + else: + self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer) + + if config.resume_from: + logger.info(f"Resuming from {config.resume_from}") + self.accelerator.load_state(config.resume_from) + self.first_epoch = int(config.resume_from.split("_")[-1]) + 1 + else: + self.first_epoch = 0 + + def compute_rewards(self, prompt_image_pairs): + reward, reward_metadata = self.reward_fn( + prompt_image_pairs["images"], prompt_image_pairs["prompts"], prompt_image_pairs["prompt_metadata"] + ) + return reward + + def step(self, epoch: int, global_step: int): + """ + Perform a single step of training. + + Args: + epoch (int): The current epoch. + global_step (int): The current global step. + + Side Effects: + - Model weights are updated + - Logs the statistics to the accelerator trackers. + - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker. + + Returns: + global_step (int): The updated global step. + """ + info = defaultdict(list) + + self.sd_pipeline.unet.train() + + for _ in range(self.config.train_gradient_accumulation_steps): + with self.accelerator.accumulate(self.sd_pipeline.unet), self.autocast(), torch.enable_grad(): + prompt_image_pairs = self._generate_samples( + batch_size=self.config.train_batch_size, + ) + + rewards = self.compute_rewards(prompt_image_pairs) + + prompt_image_pairs["rewards"] = rewards + + rewards_vis = self.accelerator.gather(rewards).detach().cpu().numpy() + + loss = self.calculate_loss(rewards) + + self.accelerator.backward(loss) + + if self.accelerator.sync_gradients: + self.accelerator.clip_grad_norm_( + self.trainable_layers.parameters() + if not isinstance(self.trainable_layers, list) + else self.trainable_layers, + self.config.train_max_grad_norm, + ) + + self.optimizer.step() + self.optimizer.zero_grad() + + info["reward_mean"].append(rewards_vis.mean()) + info["reward_std"].append(rewards_vis.std()) + info["loss"].append(loss.item()) + + # Checks if the accelerator has performed an optimization step behind the scenes + if self.accelerator.sync_gradients: + # log training-related stuff + info = {k: torch.mean(torch.tensor(v)) for k, v in info.items()} + info = self.accelerator.reduce(info, reduction="mean") + info.update({"epoch": epoch}) + self.accelerator.log(info, step=global_step) + global_step += 1 + info = defaultdict(list) + else: + raise ValueError( + "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings." + ) + # Logs generated images + if self.image_samples_callback is not None and global_step % self.config.log_image_freq == 0: + self.image_samples_callback(prompt_image_pairs, global_step, self.accelerator.trackers[0]) + + if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process: + self.accelerator.save_state() + + return global_step + + def calculate_loss(self, rewards): + """ + Calculate the loss for a batch of an unpacked sample + + Args: + rewards (torch.Tensor): + Differentiable reward scalars for each generated image, shape: [batch_size] + + Returns: + loss (torch.Tensor) + (all of these are of shape (1,)) + """ + # Loss is specific to Aesthetic Reward function used in AlignProp (https://huggingface.co/papers/2310.03739) + loss = 10.0 - (rewards).mean() + return loss + + def loss( + self, + advantages: torch.Tensor, + clip_range: float, + ratio: torch.Tensor, + ): + unclipped_loss = -advantages * ratio + clipped_loss = -advantages * torch.clamp( + ratio, + 1.0 - clip_range, + 1.0 + clip_range, + ) + return torch.mean(torch.maximum(unclipped_loss, clipped_loss)) + + def _setup_optimizer(self, trainable_layers_parameters): + if self.config.train_use_8bit_adam: + import bitsandbytes + + optimizer_cls = bitsandbytes.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + return optimizer_cls( + trainable_layers_parameters, + lr=self.config.train_learning_rate, + betas=(self.config.train_adam_beta1, self.config.train_adam_beta2), + weight_decay=self.config.train_adam_weight_decay, + eps=self.config.train_adam_epsilon, + ) + + def _save_model_hook(self, models, weights, output_dir): + self.sd_pipeline.save_checkpoint(models, weights, output_dir) + weights.pop() # ensures that accelerate doesn't try to handle saving of the model + + def _load_model_hook(self, models, input_dir): + self.sd_pipeline.load_checkpoint(models, input_dir) + models.pop() # ensures that accelerate doesn't try to handle loading of the model + + def _generate_samples(self, batch_size, with_grad=True, prompts=None): + """ + Generate samples from the model + + Args: + batch_size (int): Batch size to use for sampling + with_grad (bool): Whether the generated RGBs should have gradients attached to it. + + Returns: + prompt_image_pairs (dict[Any]) + """ + prompt_image_pairs = {} + + sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1) + + if prompts is None: + prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)]) + else: + prompt_metadata = [{} for _ in range(batch_size)] + + prompt_ids = self.sd_pipeline.tokenizer( + prompts, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.sd_pipeline.tokenizer.model_max_length, + ).input_ids.to(self.accelerator.device) + + prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0] + + if with_grad: + sd_output = self.sd_pipeline.rgb_with_grad( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=sample_neg_prompt_embeds, + num_inference_steps=self.config.sample_num_steps, + guidance_scale=self.config.sample_guidance_scale, + eta=self.config.sample_eta, + truncated_backprop_rand=self.config.truncated_backprop_rand, + truncated_backprop_timestep=self.config.truncated_backprop_timestep, + truncated_rand_backprop_minmax=self.config.truncated_rand_backprop_minmax, + output_type="pt", + ) + else: + sd_output = self.sd_pipeline( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=sample_neg_prompt_embeds, + num_inference_steps=self.config.sample_num_steps, + guidance_scale=self.config.sample_guidance_scale, + eta=self.config.sample_eta, + output_type="pt", + ) + + images = sd_output.images + + prompt_image_pairs["images"] = images + prompt_image_pairs["prompts"] = prompts + prompt_image_pairs["prompt_metadata"] = prompt_metadata + + return prompt_image_pairs + + def train(self, epochs: Optional[int] = None): + """ + Train the model for a given number of epochs + """ + global_step = 0 + if epochs is None: + epochs = self.config.num_epochs + for epoch in range(self.first_epoch, epochs): + global_step = self.step(epoch, global_step) + + def _save_pretrained(self, save_directory): + self.sd_pipeline.save_pretrained(save_directory) + self.create_model_card() + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, list[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str` or `None`, *optional*, defaults to `None`): + Name of the model. + dataset_name (`str` or `None`, *optional*, defaults to `None`): + Name of the dataset used for training. + tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. + """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + tags = tags or set() + if isinstance(tags, str): + tags = {tags} + + if hasattr(self.model.config, "unsloth_version"): + tags.add("unsloth") + + tags.update(self._tag_names) + + citation = textwrap.dedent("""\ + @article{prabhudesai2024aligning, + title = {{Aligning Text-to-Image Diffusion Models with Reward Backpropagation}}, + author = {Mihir Prabhudesai and Anirudh Goyal and Deepak Pathak and Katerina Fragkiadaki}, + year = 2024, + eprint = {arXiv:2310.03739} + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), + trainer_name="AlignProp", + trainer_citation=citation, + paper_title="Aligning Text-to-Image Diffusion Models with Reward Backpropagation", + paper_id="2310.03739", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/bco_config.py b/trl/trainer/bco_config.py new file mode 100644 index 0000000000000000000000000000000000000000..6b6438515eaeb3ab8be4c2258af2b396479a91d6 --- /dev/null +++ b/trl/trainer/bco_config.py @@ -0,0 +1,204 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any, Optional + +from transformers import TrainingArguments + + +@dataclass +class BCOConfig(TrainingArguments): + r""" + Configuration class for the [`BCOTrainer`]. + + This class includes only the parameters that are specific to BCO training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may + differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want + to use the default data collator. + max_prompt_length (`int` or `None`, *optional*, defaults to `512`): + Maximum length of the prompt. This argument is required if you want to use the default data collator. + max_completion_length (`int` or `None`, *optional*, defaults to `None`): + Maximum length of the completion. This argument is required if you want to use the default data collator + and your model is an encoder-decoder. + beta (`float`, *optional*, defaults to `0.1`): + Parameter controlling the deviation from the reference model. Higher β means less deviation from the + reference model. + label_pad_token_id (`int`, *optional*, defaults to `-100`): + Label pad token id. This argument is required if you want to use the default data collator. + padding_value (`int` or `None`, *optional*, defaults to `None`): + Padding value to use. If `None`, the padding value of the tokenizer is used. + truncation_mode (`str`, *optional*, defaults to `"keep_end"`): + Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`. + This argument is required if you want to use the default data collator. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model and reference model. + generate_during_eval (`bool`, *optional*, defaults to `False`): + If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during + evaluation. + is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`): + When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument, + you need to specify if the model returned by the callable is an encoder-decoder model. + precompute_ref_log_probs (`bool`, *optional*, defaults to `False`): + Whether to precompute reference model log probabilities for training and evaluation datasets. This is + useful when training without the reference model to reduce the total GPU memory needed. + model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): + Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a + string. + ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): + Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model + from a string. + dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): + Number of processes to use for processing the dataset. + prompt_sample_size (`int`, *optional*, defaults to `1024`): + Number of prompts that are fed to density ratio classifier. + min_density_ratio (`float`, *optional*, defaults to `0.5`): + Minimum value of the density ratio. The estimated density ratio is clamped to this value. + max_density_ratio (`float`, *optional*, defaults to `10.0`): + Maximum value of the density ratio. The estimated density ratio is clamped to this value. + """ + + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs", "ref_model_init_kwargs"] + + # Parameters whose default values are overridden from TrainingArguments + logging_steps: float = field( + default=10, + metadata={ + "help": ( + "Log every X updates steps. Should be an integer or a float in range `[0,1)`. " + "If smaller than 1, will be interpreted as ratio of total training steps." + ) + }, + ) + bf16: bool = field( + default=True, + metadata={ + "help": ( + "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change." + ) + }, + ) + + max_length: Optional[int] = field( + default=1024, + metadata={ + "help": "Maximum length of the sequences (prompt + completion) in the batch. " + "This argument is required if you want to use the default data collator." + }, + ) + max_prompt_length: Optional[int] = field( + default=512, + metadata={ + "help": "Maximum length of the prompt. " + "This argument is required if you want to use the default data collator." + }, + ) + max_completion_length: Optional[int] = field( + default=None, + metadata={ + "help": "Maximum length of the completion. This argument is required if you want to use the " + "default data collator and your model is an encoder-decoder." + }, + ) + beta: float = field( + default=0.1, + metadata={ + "help": "Parameter controlling the deviation from the reference model. " + "Higher β means less deviation from the reference model." + }, + ) + label_pad_token_id: int = field( + default=-100, + metadata={ + "help": "Label pad token id. This argument is required if you want to use the default data collator." + }, + ) + padding_value: Optional[int] = field( + default=None, + metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."}, + ) + truncation_mode: str = field( + default="keep_end", + metadata={ + "help": "Truncation mode to use when the prompt is too long. Possible values are " + "`keep_end` or `keep_start`. This argument is required if you want to use the " + "default data collator." + }, + ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropout in the model and reference model."}, + ) + generate_during_eval: bool = field( + default=False, + metadata={ + "help": "If `True`, generates and logs completions from both the model and the reference model " + "to W&B during evaluation." + }, + ) + is_encoder_decoder: Optional[bool] = field( + default=None, + metadata={ + "help": "When using the `model_init` argument (callable) to instantiate the model instead of the " + "`model` argument, you need to specify if the model returned by the callable is an " + "encoder-decoder model." + }, + ) + precompute_ref_log_probs: bool = field( + default=False, + metadata={ + "help": "Whether to precompute reference model log probabilities for training and evaluation datasets. " + "This is useful when training without the reference model to reduce the total GPU memory " + "needed." + }, + ) + model_init_kwargs: Optional[dict[str, Any]] = field( + default=None, + metadata={ + "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the " + "model from a string." + }, + ) + ref_model_init_kwargs: Optional[dict[str, Any]] = field( + default=None, + metadata={ + "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the " + "reference model from a string." + }, + ) + dataset_num_proc: Optional[int] = field( + default=None, + metadata={"help": "Number of processes to use for processing the dataset."}, + ) + prompt_sample_size: int = field( + default=1024, + metadata={"help": "Number of prompts that are fed to density ratio classifier."}, + ) + min_density_ratio: float = field( + default=0.5, + metadata={"help": "Minimum value of the density ratio. The estimated density ratio is clamped to this value."}, + ) + max_density_ratio: float = field( + default=10.0, + metadata={"help": "Maximum value of the density ratio. The estimated density ratio is clamped to this value."}, + ) diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..39dbb01e58a44e35974444b0f306cfd6c9c9865b --- /dev/null +++ b/trl/trainer/bco_trainer.py @@ -0,0 +1,1528 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import os +import random +import textwrap +import warnings +from collections import defaultdict +from contextlib import contextmanager, nullcontext +from operator import itemgetter +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +from accelerate import PartialState +from accelerate.logging import get_logger +from accelerate.utils import tqdm +from datasets import Dataset +from torch import autocast +from torch.utils.data import DataLoader, SequentialSampler +from transformers import ( + AutoModelForCausalLM, + BaseImageProcessor, + DataCollator, + FeatureExtractionMixin, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + Trainer, + TrainingArguments, + is_comet_available, + is_sklearn_available, + is_wandb_available, +) +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalLoopOutput, has_length +from transformers.utils import is_peft_available + +from ..data_utils import maybe_apply_chat_template +from ..import_utils import is_joblib_available +from ..models import create_reference_model, prepare_deepspeed +from .bco_config import BCOConfig +from .utils import ( + DPODataCollatorWithPadding, + RunningMoments, + disable_dropout_in_model, + generate_model_card, + get_comet_experiment_url, + log_table_to_comet_experiment, + pad_to_length, + peft_module_casting_to_bf16, + selective_log_softmax, +) + + +if is_peft_available(): + from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + +if is_wandb_available(): + import wandb + +if is_sklearn_available(): + from sklearn.linear_model import LogisticRegression + +if is_joblib_available(): + import joblib + +if TYPE_CHECKING: + from transformers import PreTrainedModel, PreTrainedTokenizer + +logger = get_logger(__name__) + +RUNNING_NAME = "running.json" +CLF_NAME = "clf.pkl" + + +def _tokenize( + batch: dict[str, list[Any]], + tokenizer: "PreTrainedTokenizer", + embedding_tokenizer: Optional["PreTrainedTokenizer"] = None, +) -> dict[str, list[Any]]: + """Tokenize a batch from a BCO specific dataset.""" + prompt_tokenized = tokenizer(batch["prompt"], add_special_tokens=False) + prompt_input_ids = prompt_tokenized["input_ids"] + prompt_attention_mask = prompt_tokenized["attention_mask"] + prompt_and_completion = [prompt + completion for prompt, completion in zip(batch["prompt"], batch["completion"])] + full_tokenized = tokenizer(prompt_and_completion, add_special_tokens=False) + full_input_ids = full_tokenized["input_ids"] + full_attention_mask = full_tokenized["attention_mask"] + + answer_input_ids = [f[len(p) :] for f, p in zip(full_input_ids, prompt_input_ids)] + answer_attention_mask = [f[len(p) :] for f, p in zip(full_attention_mask, prompt_attention_mask)] + + # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]` + full_concat_input_ids = [np.concatenate([p, a]) for p, a in zip(prompt_input_ids, answer_input_ids)] + # Prepare input tokens for token by token comparison + full_input_ids = [np.array(f) for f in full_input_ids] + for full, concat in zip(full_input_ids, full_concat_input_ids): + if len(full) != len(concat): + raise ValueError( + "The elements in 'full_input_ids' and 'full_concat_input_ids' must have the same pairwise length." + ) + + # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens + # can be merged together when tokenizing prompt+answer. This could result + # on the last token from the prompt being different when tokenized on its own + # vs when done as prompt+answer. + response_token_ids_start_idx = [len(p) for p in prompt_input_ids] + + # If tokenized prompt is different than both prompt+answer, then it means the + # last token has changed due to merging. + for idx, (p, f, r) in enumerate(zip(prompt_input_ids, full_input_ids, response_token_ids_start_idx)): + if not np.array_equal(p, f[:r]): + response_token_ids_start_idx[idx] -= 1 + + prompt_input_ids = [f[:r] for f, r in zip(full_input_ids, response_token_ids_start_idx)] + prompt_attention_mask = [f[:r] for f, r in zip(full_attention_mask, response_token_ids_start_idx)] + + for p, m in zip(prompt_input_ids, prompt_attention_mask): + if len(p) != len(m): + raise ValueError("Prompt input ids and attention mask should have the same length.") + + answer_input_ids = [f[r:] for f, r in zip(full_input_ids, response_token_ids_start_idx)] + answer_attention_mask = [f[r:] for f, r in zip(full_attention_mask, response_token_ids_start_idx)] + + output = dict( + prompt_input_ids=prompt_input_ids, + prompt_attention_mask=prompt_attention_mask, + answer_input_ids=answer_input_ids, + answer_attention_mask=answer_attention_mask, + ) + + if embedding_tokenizer is not None: + embedding_tokenized = embedding_tokenizer(batch["prompt"], truncation=True, add_special_tokens=False) + + output.update( + { + "embedding_input_ids": embedding_tokenized["input_ids"], + "embedding_attention_mask": embedding_tokenized["attention_mask"], + } + ) + + return output + + +def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = None, **kwargs) -> dict: + """Process tokens of a BCO specific dataset. + + At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation + in case the prompt + completion responses is/are too long. First + we truncate the prompt; if we're still too long, we truncate the completion. + + We also create the labels for the completion responses, which are of length equal to + the sum of the length of the prompt and the completion response, with + label_pad_token_id for the prompt tokens. + """ + prompt = example["prompt"] + completion = example["completion"] + + batch = { + f"{kwargs['prefix']}prompt": prompt, + f"{kwargs['prefix']}completion": completion, + f"{kwargs['prefix']}label": example["label"], + } + + if not kwargs["is_encoder_decoder"]: + # Check issues below for more details + # 1. https://github.com/huggingface/trl/issues/907 + # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 + # 3. https://github.com/LianjiaTech/BELLE/issues/337 + + if not isinstance(prompt, str): + raise ValueError(f"prompt should be an str but got {type(prompt)}") + + if not isinstance(completion, str): + raise ValueError(f"completion should be an str but got {type(completion)}") + + # keys of format prompt_* refers to just the prompt and answer_* refers to just the answer + all_tokens = { + "prompt_input_ids": example["prompt_input_ids"], + "prompt_attention_mask": example["prompt_attention_mask"], + "answer_input_ids": example["answer_input_ids"], + "answer_attention_mask": example["answer_attention_mask"], + } + + # calculate max length by checking if BOS/EOS is already there + max_length = kwargs["max_length"] + bos_token_id = kwargs["tokenizer"].bos_token_id + eos_token_id = kwargs["tokenizer"].eos_token_id + if bos_token_id != all_tokens["prompt_input_ids"][0]: + max_length -= 1 + if eos_token_id != all_tokens["answer_input_ids"][-1]: + max_length -= 1 + + # if combined sequence is too long (> max_length - 1 for BOS token - 1 for EOS), truncate the prompt + if len(all_tokens["prompt_input_ids"]) + len(all_tokens["answer_input_ids"]) > max_length: + for k in ["prompt_input_ids", "prompt_attention_mask"]: + if kwargs["truncation_mode"] == "keep_start": + all_tokens[k] = all_tokens[k][: kwargs["max_prompt_length"]] + elif kwargs["truncation_mode"] == "keep_end": + all_tokens[k] = all_tokens[k][-kwargs["max_prompt_length"] :] + else: + raise ValueError(f"Unknown truncation mode: {kwargs['truncation_mode']}") + + # if that's still too long, truncate the response + if len(all_tokens["prompt_input_ids"]) + len(all_tokens["answer_input_ids"]) > max_length: + for k in ["answer_input_ids", "answer_attention_mask"]: + all_tokens[k] = all_tokens[k][: max_length - kwargs["max_prompt_length"]] + + # all input_ids and attention mask as is. We then check if we need to add BOS/EOS tokens + batch[f"{kwargs['prefix']}prompt_input_ids"] = all_tokens["prompt_input_ids"] + batch[f"{kwargs['prefix']}prompt_attention_mask"] = all_tokens["prompt_attention_mask"] + batch[f"{kwargs['prefix']}completion_input_ids"] = ( + all_tokens["prompt_input_ids"] + all_tokens["answer_input_ids"] + ) + batch[f"{kwargs['prefix']}completion_attention_mask"] = ( + all_tokens["prompt_attention_mask"] + all_tokens["answer_attention_mask"] + ) + + # add BOS, which affects both prompt and the full completion + if bos_token_id is not None: + if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]: + batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[ + f"{kwargs['prefix']}prompt_input_ids" + ] + batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[ + f"{kwargs['prefix']}prompt_attention_mask" + ] + batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[ + f"{kwargs['prefix']}completion_input_ids" + ] + batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[ + f"{kwargs['prefix']}completion_attention_mask" + ] + # add EOS, which affects only the full completion + if len(all_tokens["answer_input_ids"]) == 0 or eos_token_id != all_tokens["answer_input_ids"][-1]: + batch[f"{kwargs['prefix']}completion_input_ids"] = batch[f"{kwargs['prefix']}completion_input_ids"] + [ + eos_token_id + ] + batch[f"{kwargs['prefix']}completion_attention_mask"] = batch[ + f"{kwargs['prefix']}completion_attention_mask" + ] + [1] + + batch[f"{kwargs['prefix']}completion_labels"] = batch[f"{kwargs['prefix']}completion_input_ids"][:] + batch[f"{kwargs['prefix']}completion_labels"][: len(batch[f"{kwargs['prefix']}prompt_input_ids"])] = [ + kwargs["label_pad_token_id"] + ] * len(batch[f"{kwargs['prefix']}prompt_input_ids"]) + else: + completion_tokens = kwargs["tokenizer"]( + completion, truncation=True, max_length=kwargs["max_completion_length"], add_special_tokens=True + ) + prompt_tokens = kwargs["tokenizer"]( + prompt, truncation=True, max_length=kwargs["max_prompt_length"], add_special_tokens=True + ) + + batch[f"{kwargs['prefix']}prompt_input_ids"] = prompt_tokens["input_ids"] + batch[f"{kwargs['prefix']}prompt_attention_mask"] = prompt_tokens["attention_mask"] + + batch[f"{kwargs['prefix']}completion_labels"] = completion_tokens["input_ids"] + batch[f"{kwargs['prefix']}completion_attention_mask"] = completion_tokens["attention_mask"] + if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"): + batch[f"{kwargs['prefix']}completion_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( + labels=torch.tensor(batch["completion_labels"]) + ) + + return batch + + +class BCOTrainer(Trainer): + r""" + Initialize BCOTrainer from [BCO](https://huggingface.co/papers/2404.04656) paper. + + Args: + model (`transformers.PreTrainedModel`): + The model to train, preferably an `AutoModelForSequenceClassification`. + ref_model (`PreTrainedModelWrapper`): + Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no + reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized. + args (`BCOConfig`): + The arguments to use for training. + train_dataset (`datasets.Dataset`): + The dataset to use for training. + eval_dataset (`datasets.Dataset`): + The dataset to use for evaluation. + processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*): + Processing class used to process the data. If provided, will be used to automatically process the inputs + for the model, and it will be saved along the model to make it easier to rerun an interrupted training or + reuse the fine-tuned model. + data_collator (`transformers.DataCollator`, *optional*, defaults to `None`): + The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used + which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences. + model_init (`Callable[[], transformers.PreTrainedModel]`): + The model initializer to use for training. If None is specified, the default model initializer will be used. + callbacks (`list[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + peft_config (`dict`, defaults to `None`): + The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function to use to compute the metrics. Must take a `EvalPrediction` and return + a dictionary string to metric values. + model_adapter_name (`str`, defaults to `None`): + Name of the train target PEFT adapter, when using LoRA with multiple adapters. + ref_adapter_name (`str`, defaults to `None`): + Name of the reference PEFT adapter, when using LoRA with multiple adapters. + """ + + _tag_names = ["trl", "bco"] + + def __init__( + self, + model: Union[PreTrainedModel, nn.Module, str] = None, + ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + args: BCOConfig = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, + processing_class: Optional[ + Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] + ] = None, + data_collator: Optional[DataCollator] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + callbacks: Optional[list[TrainerCallback]] = None, + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + peft_config: Optional[dict] = None, + compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None, + model_adapter_name: Optional[str] = None, + ref_adapter_name: Optional[str] = None, + embedding_func: Optional[Callable] = None, + embedding_tokenizer: Optional[PreTrainedTokenizerBase] = None, + ): + if embedding_func is not None and not (is_sklearn_available() and is_joblib_available()): + raise ImportError( + "BCOTrainer with UDM requires the scikit-learn and joblib libraries. Please install it with `pip install scikit-learn joblib`." + ) + + if type(args) is TrainingArguments: + raise ValueError("Please use `BCOConfig` instead `TrainingArguments`.") + + if not isinstance(model, str) and model is not None and ref_model is model: + raise ValueError( + "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " + "same as `model`, you must mass a copy of it, or `None` if you use peft." + ) + + if args.model_init_kwargs is None: + model_init_kwargs = {} + elif not isinstance(model, str): + raise ValueError("You passed model_kwargs to the BCOTrainer. But your model is already instantiated.") + else: + model_init_kwargs = args.model_init_kwargs + torch_dtype = model_init_kwargs.get("torch_dtype") + if torch_dtype is not None: + # Convert to `torch.dtype` if an str is passed + if isinstance(torch_dtype, str) and torch_dtype != "auto": + torch_dtype = getattr(torch, torch_dtype) + if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}." + ) + model_init_kwargs["torch_dtype"] = torch_dtype + + if args.ref_model_init_kwargs is None: + ref_model_init_kwargs = {} + elif not isinstance(ref_model, str): + raise ValueError( + "You passed ref_model_kwargs to the BCOTrainer. But your ref_model is already instantiated." + ) + else: + ref_model_init_kwargs = args.ref_model_init_kwargs + torch_dtype = ref_model_init_kwargs.get("torch_dtype") + if torch_dtype is not None: + # Convert to `torch.dtype` if an str is passed + if isinstance(torch_dtype, str) and torch_dtype != "auto": + torch_dtype = getattr(torch, torch_dtype) + if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"Invalid `torch_dtype` passed to the BCOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}." + ) + ref_model_init_kwargs["torch_dtype"] = torch_dtype + + if isinstance(model, str): + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + + if isinstance(ref_model, str): + ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs) + + # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` + # has been called in order to properly call autocast if needed. + self._peft_has_been_casted_to_bf16 = False + + if not is_peft_available() and peft_config is not None: + raise ValueError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + # if model is a peft model and we have a peft_config, we merge and unload it first + if isinstance(model, PeftModel): + model = model.merge_and_unload() + + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): + _support_gc_kwargs = hasattr( + args, "gradient_checkpointing_kwargs" + ) and "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + + prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} + + if _support_gc_kwargs: + prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) + elif args.gradient_checkpointing: + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + # get peft model with the given config + model = get_peft_model(model, peft_config) + if args.bf16 and getattr(model, "is_loaded_in_4bit", False): + peft_module_casting_to_bf16(model) + # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager + self._peft_has_been_casted_to_bf16 = True + + # For models that use gradient_checkpointing, we need to attach a hook that enables input + # to explicitly have `requires_grad=True`, otherwise training will either silently + # fail or completely fail. + elif args.gradient_checkpointing: + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if args.generate_during_eval and not (is_wandb_available() or is_comet_available()): + raise ValueError( + "`generate_during_eval=True` requires Weights and Biases or Comet to be installed." + " Please install `wandb` or `comet-ml` to resolve." + ) + + if model is not None: + self.is_encoder_decoder = model.config.is_encoder_decoder + elif args.is_encoder_decoder is None: + raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.") + else: + self.is_encoder_decoder = args.is_encoder_decoder + + self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) + self.model_adapter_name = model_adapter_name + self.ref_adapter_name = ref_adapter_name + + if ref_model: + self.ref_model = ref_model + elif self.is_peft_model or args.precompute_ref_log_probs: + # The `model` with adapters turned off will be used as the reference model + self.ref_model = None + else: + self.ref_model = create_reference_model(model) + + if processing_class is None: + raise ValueError( + "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding" + ) + if args.max_length is None: + warnings.warn( + "When using DPODataCollatorWithPadding, you should set `max_length` in the `BCOConfig`. " + "It will be set to `512` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_length = 512 + if args.max_length is not None: + max_length = args.max_length + + if args.max_prompt_length is None: + warnings.warn( + "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the `BCOConfig`. " + "It will be set to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_prompt_length = 128 + if args.max_prompt_length is not None: + max_prompt_length = args.max_prompt_length + + max_completion_length = None + if args.max_completion_length is None and self.is_encoder_decoder: + warnings.warn( + "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the BCOTrainer's init" + " it will be set to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_completion_length = 128 + if args.max_completion_length is not None and self.is_encoder_decoder: + max_completion_length = args.max_completion_length + + if data_collator is None: + data_collator = DPODataCollatorWithPadding( + pad_token_id=processing_class.pad_token_id, + label_pad_token_id=args.label_pad_token_id, + is_encoder_decoder=self.is_encoder_decoder, + ) + + if args.remove_unused_columns: + args.remove_unused_columns = False + # warn users + warnings.warn( + "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your BCOConfig" + " we have set it for you, but you should do it yourself in the future.", + UserWarning, + ) + + self.use_dpo_data_collator = True + else: + self.use_dpo_data_collator = False + + # Disable dropout in the model and reference model + if args.disable_dropout: + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + self.max_length = max_length + self.generate_during_eval = args.generate_during_eval + self.label_pad_token_id = args.label_pad_token_id + self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id + self.max_prompt_length = max_prompt_length + self.truncation_mode = args.truncation_mode + self.max_completion_length = max_completion_length + self.precompute_ref_log_probs = args.precompute_ref_log_probs + + # Since ref_logs are precomputed on the first call to get_train/eval_dataloader + # keep track of first called to avoid computation of future calls + self._precomputed_train_ref_log_probs = False + self._precomputed_eval_ref_log_probs = False + + # metric + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + # BCO parameter + self.beta = args.beta + self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) + self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0) + if self.aux_loss_enabled and self.aux_loss_coef == 0.0: + warnings.warn( + "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to " + "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value " + "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary " + "loss.", + UserWarning, + ) + + # Underlying Distribution Matching argument + self.embedding_func = embedding_func + self.embedding_tokenizer = embedding_tokenizer + + # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the + # input tensor associated with the key "input_ids". However, in BCO, the sampled data does not include the + # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result, + # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point + # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's + # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been + # issued. + model.warnings_issued["estimate_tokens"] = True + + with PartialState().main_process_first(): + # Apply the chat template if needed + train_dataset = train_dataset.map( + maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc + ) + if eval_dataset is not None: + eval_dataset = eval_dataset.map( + maybe_apply_chat_template, + fn_kwargs={"tokenizer": processing_class}, + num_proc=args.dataset_num_proc, + ) + + # Tokenize and prepare the training datasets + train_dataset = train_dataset.map( + _tokenize, + batched=True, + fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer}, + num_proc=args.dataset_num_proc, + desc="Tokenizing train dataset", + ) + + # Prepare the datasets + fn_kwargs = { + "prefix": "", + "is_encoder_decoder": self.is_encoder_decoder, + "tokenizer": processing_class, + "max_length": self.max_length, + "truncation_mode": self.truncation_mode, + "label_pad_token_id": self.label_pad_token_id, + "max_prompt_length": self.max_prompt_length, + "max_completion_length": self.max_completion_length, + } + train_dataset = train_dataset.map( + _process_tokens, + fn_kwargs=fn_kwargs, + num_proc=args.dataset_num_proc, + desc="Processing tokenized train dataset", + ) + + if eval_dataset is not None: + # Tokenize + eval_dataset = eval_dataset.map( + _tokenize, + fn_kwargs={"tokenizer": processing_class, "embedding_tokenizer": self.embedding_tokenizer}, + batched=True, + num_proc=args.dataset_num_proc, + desc="Tokenizing eval dataset", + ) + + # Process + fn_kwargs = { + "prefix": "", + "is_encoder_decoder": self.is_encoder_decoder, + "tokenizer": processing_class, + "max_length": self.max_length, + "truncation_mode": self.truncation_mode, + "label_pad_token_id": self.label_pad_token_id, + "max_prompt_length": self.max_prompt_length, + "max_completion_length": self.max_completion_length, + } + eval_dataset = eval_dataset.map( + _process_tokens, + fn_kwargs=fn_kwargs, + num_proc=args.dataset_num_proc, + desc="Processing tokenized eval dataset", + ) + + desirable = train_dataset.filter( + lambda x: x["label"], num_proc=args.dataset_num_proc, desc="Filtering desirable examples" + ) + undesirable = train_dataset.filter( + lambda x: not x["label"], num_proc=args.dataset_num_proc, desc="Filtering undesirable examples" + ) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. + self.model_accepts_loss_kwargs = False + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + if not hasattr(self, "accelerator"): + raise AttributeError( + "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." + ) + + # Deepspeed Zero-3 does not support precompute_ref_log_probs + if self.is_deepspeed_enabled: + if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs: + raise ValueError( + "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`." + ) + + if self.ref_model is None: + if not (self.is_peft_model or self.precompute_ref_log_probs): + raise ValueError( + "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`" + ) + else: + if self.is_deepspeed_enabled: + self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + self.running = RunningMoments(accelerator=self.accelerator) + + if self.embedding_func is None or args.resume_from_checkpoint: + return + + chosen_embeddings = self._get_sample_prompt_embeddings(desirable, sample_size=self.args.prompt_sample_size) + rejected_embeddings = self._get_sample_prompt_embeddings(undesirable, sample_size=self.args.prompt_sample_size) + + embeddings = torch.cat((chosen_embeddings, rejected_embeddings), dim=0) + labels = torch.cat( + (torch.ones_like(chosen_embeddings[:, 0]), torch.zeros_like(rejected_embeddings[:, 0])), dim=0 + ) + + self.clf = LogisticRegression(class_weight="balanced").fit( + embeddings.cpu().float().numpy(), labels.cpu().numpy() + ) + chosen_mean = self.clf.score( + chosen_embeddings.cpu().float().numpy(), torch.ones_like(chosen_embeddings[:, 0]).cpu().numpy() + ) + rejected_mean = self.clf.score( + rejected_embeddings.cpu().float().numpy(), torch.zeros_like(rejected_embeddings[:, 0]).cpu().numpy() + ) + logger.info(f"UDM classifier training scores: chosen: {chosen_mean}, rejected: {rejected_mean}") + + @property + def match_underlying_distribution(self): + return self.embedding_func is not None and self.embedding_tokenizer is not None + + def _get_chosen_prob(self, prompt_embeddings: torch.FloatTensor) -> torch.FloatTensor: + """ + Calculates the probability if the given prompt embedding is from desirable dataset. + This function calculates the probability in the process and ensemble across processes. + """ + dtype = prompt_embeddings.dtype + device = prompt_embeddings.device + rank = self.accelerator.process_index + + padded_prompt_embeddings = self.accelerator.pad_across_processes( + prompt_embeddings, pad_index=self.embedding_tokenizer.pad_token_id + ) + sample_size = padded_prompt_embeddings.shape[0] + nonzero = padded_prompt_embeddings.mean(dim=1) != self.embedding_tokenizer.pad_token_id + prompt_embeddings = self.accelerator.gather(padded_prompt_embeddings) + + # cannot predict for all empty values + if prompt_embeddings.shape[0] == 0: + return torch.tensor([], device=device, dtype=dtype) + + prob = self.clf.predict_proba(prompt_embeddings.cpu().float().numpy())[:, 1] + prob = torch.as_tensor(prob, dtype=dtype, device=device) + prob = self.accelerator.reduce(prob, reduction="mean") + + prob = prob[sample_size * rank : sample_size * (rank + 1)] + prob = prob[nonzero] + + return prob + + def _vectorize_prompt(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor) -> torch.FloatTensor: + """ + Replaces processing_class.pad_token_id to embedding_tokenizer.pad_token_id + and applies self.embedding_func + """ + input_ids = torch.where( + input_ids == self.processing_class.pad_token_id, + self.embedding_tokenizer.pad_token_id, + input_ids, + ) + + with torch.no_grad(): + embeddings = self.embedding_func( + input_ids=input_ids, + attention_mask=attention_mask, + ) + + return embeddings + + def _get_prompt_embeddings( + self, batch: dict[str, Union[list, torch.LongTensor]] + ) -> tuple[torch.FloatTensor, torch.FloatTensor]: + """Extract embeddings from frozen embedding model""" + + if not self.match_underlying_distribution: + return None, None + + embeddings = self._vectorize_prompt( + input_ids=batch["embedding_input_ids"], + attention_mask=batch["embedding_attention_mask"], + ) + + chosen_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is True] + rejected_idx = [i for i in range(len(batch["label"])) if batch["label"][i] is False] + + chosen_embeddings = embeddings[chosen_idx, ...] + rejected_embeddings = embeddings[rejected_idx, ...] + + return (chosen_embeddings, rejected_embeddings) + + def _get_sample_prompt_embeddings(self, dataset: Dataset, sample_size: int = 512) -> torch.FloatTensor: + """ + Sample instances from dataset and get prompt embeddings. + Used for density ratio classifier training. + """ + n_samples = min(len(dataset), sample_size) + rand_indices = np.random.choice(len(dataset), size=(n_samples,)) + + embedding_dataset = dataset.select(rand_indices) + + dataloader_params = { + "batch_size": self.args.per_device_train_batch_size, + "collate_fn": self.data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "shuffle": False, + } + + # prepare dataloader + data_loader = self.accelerator.prepare(DataLoader(embedding_dataset, **dataloader_params)) + + with torch.no_grad(): + all_embeddings = torch.empty(0) + for padded_batch in tqdm(iterable=data_loader, desc="Building sample prompt embeddings"): + embeddings = self._vectorize_prompt( + input_ids=padded_batch["embedding_input_ids"], + attention_mask=padded_batch["embedding_attention_mask"], + ) + embeddings = self.accelerator.gather_for_metrics(embeddings) + all_embeddings = torch.cat((all_embeddings, embeddings.cpu())) + + return all_embeddings + + def _save_optimizer_and_scheduler(self, output_dir): + output_dir = output_dir if output_dir is not None else self.args.output_dir + super()._save_optimizer_and_scheduler(output_dir) + + if self.accelerator.is_main_process: + # When saving optimizer and scheduler to checkpoint, save also the running delta object. + self.running.save_to_json(os.path.join(output_dir, RUNNING_NAME)) + + if self.match_underlying_distribution: + joblib.dump(self.clf, os.path.join(output_dir, CLF_NAME), compress=True) + + def _load_optimizer_and_scheduler(self, checkpoint): + if checkpoint is None: + logger.warning_once(f"Missing Checkpoint {checkpoint}") + return + + super()._load_optimizer_and_scheduler(checkpoint) + + # when loading optimizer and scheduler from checkpoint, also load the running delta object. + running_file = os.path.join(checkpoint, RUNNING_NAME) + if os.path.isfile(running_file): + self.running = RunningMoments.load_from_json(self.accelerator, running_file) + + if self.match_underlying_distribution: + clf_file = os.path.join(checkpoint, CLF_NAME) + if os.path.isfile(clf_file): + self.clf = joblib.load(clf_file) + + @contextmanager + def null_ref_context(self): + """Context manager for handling null reference model (that is, peft adapter manipulation).""" + with ( + self.accelerator.unwrap_model(self.model).disable_adapter() + if self.is_peft_model and not self.ref_adapter_name + else nullcontext() + ): + if self.ref_adapter_name: + self.model.set_adapter(self.ref_adapter_name) + yield + if self.ref_adapter_name: + self.model.set_adapter(self.model_adapter_name or "default") + + def get_train_dataloader(self) -> DataLoader: + """ + Returns the training [`~torch.utils.data.DataLoader`]. + + Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`. + """ + + if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs: + dataloader_params = { + "batch_size": self.args.per_device_train_batch_size, + "collate_fn": self.data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "shuffle": False, + } + + # prepare dataloader + data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params)) + reference_completion_logps = [] + + for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"): + reference_completion_logp = self.compute_reference_log_probs(padded_batch) + + reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp) + reference_completion_logps.append(reference_completion_logp.cpu()) + + self.train_dataset = self.train_dataset.add_column( + name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy() + ) + + self._precomputed_train_ref_log_probs = True + + return super().get_train_dataloader() + + def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: + """ + Returns the evaluation [`~torch.utils.data.DataLoader`]. + + Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`. + + Args: + eval_dataset (`torch.utils.data.Dataset`, *optional*): + If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted + by the `model.forward()` method are automatically removed. It must implement `__len__`. + """ + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + + if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs: + dataloader_params = { + "batch_size": self.args.per_device_eval_batch_size, + "collate_fn": self.data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "shuffle": False, + } + + # prepare dataloader + data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params)) + + reference_completion_logps = [] + + for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"): + reference_completion_logp = self.compute_reference_log_probs(padded_batch) + + reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp) + reference_completion_logps.append(reference_completion_logp.cpu()) + + eval_dataset = eval_dataset.add_column( + name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy() + ) + + # Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs + if self.eval_dataset is not None: + self.eval_dataset = eval_dataset + self._precomputed_eval_ref_log_probs = True + + return super().get_eval_dataloader(eval_dataset=eval_dataset) + + def compute_reference_log_probs(self, padded_batch: dict) -> dict: + """Computes log probabilities of the reference model for a single padded batch of a BCO specific dataset.""" + with torch.no_grad(): + if self.ref_model is None: + with self.null_ref_context(): + if self.is_encoder_decoder: + completion_logits = self.model( + padded_batch["prompt_input_ids"], + attention_mask=padded_batch["prompt_attention_mask"], + decoder_input_ids=padded_batch.get("completion_decoder_input_ids"), + labels=padded_batch["completion_labels"], + ).logits + + else: + completion_logits = self.model( + padded_batch["completion_input_ids"], + attention_mask=padded_batch["completion_attention_mask"], + ).logits + + else: + if self.is_encoder_decoder: + completion_logits = self.ref_model( + padded_batch["prompt_input_ids"], + attention_mask=padded_batch["prompt_attention_mask"], + decoder_input_ids=padded_batch.get("completion_decoder_input_ids"), + labels=padded_batch["completion_labels"], + ).logits + + else: + completion_logits = self.ref_model( + padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"] + ).logits + + completion_logps = self.get_batch_logps( + completion_logits, + padded_batch["completion_labels"], + average_log_prob=False, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) + + return completion_logps + + @staticmethod + def get_batch_logps( + logits: torch.FloatTensor, + labels: torch.LongTensor, + average_log_prob: bool = False, + label_pad_token_id: int = -100, + is_encoder_decoder: bool = False, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length) + average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. + """ + if logits.shape[:-1] != labels.shape: + raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") + + if not is_encoder_decoder: + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + else: + # Fixes end-dec RuntimeError + labels = labels.clone() + + loss_mask = labels != label_pad_token_id + + # dummy token; we'll ignore the losses on these tokens later + labels[labels == label_pad_token_id] = 0 + + per_token_logps = selective_log_softmax(logits, labels) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def forward( + self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]] + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + model_kwargs = ( + { + "labels": batch["completion_labels"], + "decoder_input_ids": batch.get("completion_decoder_input_ids"), + } + if self.is_encoder_decoder + else {} + ) + if self.aux_loss_enabled: + model_kwargs["output_router_logits"] = True + + outputs = model( + batch["completion_input_ids"], + attention_mask=batch["completion_attention_mask"], + **model_kwargs, + ) + completion_logits = outputs.logits + + completion_logps = self.get_batch_logps( + completion_logits, + batch["completion_labels"], + average_log_prob=False, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) + + if completion_logps.shape[0] != len(batch["label"]): + raise ValueError( + "There is a mismatch between the number of examples in this batch and the number of " + "examples for which an output sequence was predicted." + ) + + chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True] + rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False] + + chosen_logps = completion_logps[chosen_idx, ...] + rejected_logps = completion_logps[rejected_idx, ...] + + chosen_logits = completion_logits[chosen_idx, ...] + rejected_logits = completion_logits[rejected_idx, ...] + + if self.aux_loss_enabled: + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, outputs.aux_loss) + else: + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits) + + def _get_udm_weight(self, rejected_embeddings: torch.FloatTensor) -> torch.FloatTensor: + prob_desirable = self._get_chosen_prob(rejected_embeddings) + min_ratio = self.args.min_density_ratio + max_ratio = self.args.max_density_ratio + + weight = (prob_desirable / (1 - prob_desirable + 1e-8)).clamp(min=min_ratio, max=max_ratio) + + return weight + + def bco_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + reference_chosen_logps: torch.FloatTensor, + reference_rejected_logps: torch.FloatTensor, + chosen_embeddings: Optional[torch.FloatTensor], + rejected_embeddings: Optional[torch.FloatTensor], + do_train: bool = True, + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Compute the BCO loss for a batch of policy and reference model log probabilities. + + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,) + reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,) + reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,) + chosen_embeddings: embeddings of desirable prompts + rejected_embeddings: embeddings of undesirable prompts + + Returns: + A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, delta). + The losses tensor contains the BCO loss for each example in the batch. + The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. + The delta value contains the moving average of all implicit rewards. + """ + + chosen_logratios = policy_chosen_logps - reference_chosen_logps + chosen_rewards = self.beta * chosen_logratios + + rejected_logratios = policy_rejected_logps - reference_rejected_logps + rejected_rewards = self.beta * rejected_logratios + + if do_train: + self.running.update(torch.cat((chosen_rewards, rejected_rewards), 0).detach()) + delta = torch.as_tensor(self.running.mean, device=chosen_rewards.device) + + chosen_losses = -F.logsigmoid(chosen_rewards - delta) + rejected_losses = -F.logsigmoid(-(rejected_rewards - delta)) + + if self.match_underlying_distribution: + chosen_weight = torch.ones_like(chosen_losses) + rejected_weight = self._get_udm_weight(rejected_embeddings) + + losses = torch.cat((chosen_weight * chosen_losses, rejected_weight * rejected_losses), dim=0) + else: + losses = torch.cat((chosen_losses, rejected_losses), dim=0) + + return losses, chosen_rewards, rejected_rewards, delta + + def get_batch_loss_metrics( + self, + model, + batch: dict[str, Union[list, torch.LongTensor]], + do_train: bool = True, + ): + """Compute the BCO loss and other metrics for the given batch of inputs for train or test.""" + metrics = {} + batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} + + forward_output = self.forward(model, batch) + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + ) = forward_output[:4] + if self.aux_loss_enabled: + aux_loss = forward_output[4] + + # if reference_logps in batch use them, otherwise use the reference model + if "reference_logps" in batch: + chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True] + rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False] + + reference_chosen_logps = batch["reference_logps"][chosen_idx, ...] + reference_rejected_logps = batch["reference_logps"][rejected_idx, ...] + else: + with torch.no_grad(): + if self.ref_model is None: + with self.null_ref_context(): + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + ) = self.forward(self.model, batch)[:4] + else: + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + ) = self.forward(self.ref_model, batch)[:4] + + chosen_embeddings, rejected_embeddings = self._get_prompt_embeddings(batch) + + losses, chosen_rewards, rejected_rewards, delta = self.bco_loss( + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + chosen_embeddings, + rejected_embeddings, + do_train=do_train, + ) + metrics["delta"] = self.accelerator.gather_for_metrics(delta).mean().item() + + num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device) + num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device) + + all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item() + all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item() + + if all_num_chosen > 0: + metrics["rewards/chosen_sum"] = ( + self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item() + ) + metrics["logps/chosen_sum"] = ( + self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item() + ) + metrics["logits/chosen_sum"] = ( + self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item() + ) + metrics["count/chosen"] = all_num_chosen + + if all_num_rejected > 0: + metrics["rewards/rejected_sum"] = ( + self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item() + ) + metrics["logps/rejected_sum"] = ( + self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item() + ) + metrics["logits/rejected_sum"] = ( + self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item() + ) + metrics["count/rejected"] = all_num_rejected + + loss = losses.nanmean() + if self.aux_loss_enabled: + loss += self.aux_loss_coef * aux_loss + + return loss, metrics + + def compute_loss( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: dict[str, Union[torch.Tensor, Any]], + return_outputs=False, + num_items_in_batch=None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]: + compute_loss_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with compute_loss_context_manager: + loss, metrics = self.get_batch_loss_metrics(model, inputs) + + # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class: + loss = loss.to(self.args.device) + # force log the metrics + if self.accelerator.is_main_process: + self.store_metrics(metrics, train_eval="train") + + if return_outputs: + return (loss, metrics) + return loss + + def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: + for key, value in metrics.items(): + self._stored_metrics[train_eval][key].append(value) + + def _get_train_sampler(self, dataset: Optional[Dataset] = None) -> Optional[torch.utils.data.Sampler]: + if dataset is None: + dataset = self.train_dataset + if dataset is None or not has_length(dataset): + return None + return SequentialSampler(dataset) + + def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]: + """Generate samples from the model and reference model for the given batch of inputs.""" + + # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with + # the torch amp context manager as some hidden states are silently casted to full precision. + generate_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + with generate_context_manager: + policy_output = model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.processing_class.pad_token_id, + ) + + # if reference_output in batch use that otherwise use the reference model + if "reference_output" in batch: + reference_output = batch["reference_output"] + else: + if self.ref_model is None: + with self.null_ref_context(): + reference_output = self.model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.processing_class.pad_token_id, + ) + else: + reference_output = self.ref_model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.processing_class.pad_token_id, + ) + + policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id) + policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) + + reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id) + reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True) + + return policy_output_decoded, reference_output_decoded + + def prediction_step( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[list[str]] = None, + ): + if ignore_keys is None: + if hasattr(model, "config"): + ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + prediction_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + with torch.no_grad(), prediction_context_manager: + loss, metrics = self.get_batch_loss_metrics(model, inputs, do_train=False) + + # force log the metrics + if self.accelerator.is_main_process: + self.store_metrics(metrics, train_eval="eval") + + if prediction_loss_only: + return (loss.detach(), None, None) + + # logits for the chosen and rejected samples from model + logits_dict = {} + if "logits/chosen_sum" in metrics: + logits_dict["eval_logits/chosen"] = metrics["logits/chosen_sum"] + if "logits/rejected_sum" in metrics: + logits_dict["eval_logits/rejected"] = metrics["logits/rejected_sum"] + logits = [v for k, v in logits_dict.items() if k not in ignore_keys] + logits = torch.tensor(logits, device=self.accelerator.device) + labels = torch.zeros(logits.shape[0], device=self.accelerator.device) + + return (loss.detach(), logits, labels) + + def evaluation_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: Optional[bool] = None, + ignore_keys: Optional[list[str]] = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Overriding built-in evaluation loop to store metrics for each batch. + Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. + """ + + # Sample and save to game log if requested (for one batch to save time) + if self.generate_during_eval: + # Generate random indices within the range of the total number of samples + num_samples = len(dataloader.dataset) + random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) + + # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader + random_batch_dataset = dataloader.dataset.select(random_indices) + random_batch = self.data_collator(random_batch_dataset) + random_batch = self._prepare_inputs(random_batch) + + target_indicies = [i for i in range(len(random_batch["label"])) if random_batch["label"][i] is False] + target_batch = { + "prompt_input_ids": random_batch["prompt_input_ids"][target_indicies], + "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indicies], + "prompt": itemgetter(*target_indicies)(random_batch["prompt"]), + } + policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch) + + table = pd.DataFrame( + columns=["Prompt", "Policy", "Ref Model"], + data=[ + [prompt, pol[len(prompt) :], ref[len(prompt) :]] + for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded) + ], + ) + if "wandb" in self.args.report_to: + wandb.log({"game_log": wandb.Table(data=table)}) + + if "comet_ml" in self.args.report_to: + log_table_to_comet_experiment( + name="game_log.csv", + table=table, + ) + + # Base evaluation + initial_output = super().evaluation_loop( + dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix + ) + + return initial_output + + def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: + """ + Log `logs` on the various objects watching training, including stored metrics. + + Args: + logs (`dict[str, float]`): + The values to log. + start_time (`float` or `None`, *optional*, defaults to `None`): + Start time of the training. + """ + # logs either has 'loss' or 'eval_loss' + train_eval = "train" if "loss" in logs else "eval" + # train metrics should have no prefix, eval should have 'eval_' + prefix = "eval_" if train_eval == "eval" else "" + # accumulate average metrics from sums and lengths + for split in ["chosen", "rejected"]: + if f"count/{split}" in self._stored_metrics[train_eval]: + count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item() + for metric in ["rewards", "logps", "logits"]: + logs[f"{prefix}{metric}/{split}"] = ( + torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item() + / count_sum + ) + # delete obsolete metric + del self._stored_metrics[train_eval][f"{metric}/{split}_sum"] + del self._stored_metrics[train_eval][f"count/{split}"] + # calculate reward margin + if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs: + logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"] + # Add averaged stored metrics to logs + for key, metrics in self._stored_metrics[train_eval].items(): + logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + return super().log(logs, start_time) + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, list[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str` or `None`, *optional*, defaults to `None`): + Name of the model. + dataset_name (`str` or `None`, *optional*, defaults to `None`): + Name of the dataset used for training. + tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. + """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + tags = tags or set() + if isinstance(tags, str): + tags = {tags} + + if hasattr(self.model.config, "unsloth_version"): + tags.add("unsloth") + + tags.update(self._tag_names) + + citation = textwrap.dedent("""\ + @article{jung2024binary, + title = {{Binary Classifier Optimization for Large Language Model Alignment}}, + author = {Seungjae Jung and Gunsoo Han and Daniel Wontae Nam and Kyoung{-}Woon On}, + year = 2024, + eprint = {arXiv:2404.04656} + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), + trainer_name="BCO", + trainer_citation=citation, + paper_title="Binary Classifier Optimization for Large Language Model Alignment", + paper_id="2404.04656", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..7fd60fb34aa71bb2bced3b812280082189696c9c --- /dev/null +++ b/trl/trainer/callbacks.py @@ -0,0 +1,570 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Optional, Union + +import pandas as pd +import torch +from accelerate import Accelerator +from accelerate.state import AcceleratorState +from accelerate.utils import gather_object, is_wandb_available +from transformers import ( + GenerationConfig, + PreTrainedModel, + PreTrainedTokenizerBase, + Trainer, + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, +) +from transformers.trainer_utils import has_length +from transformers.utils import is_rich_available + +from ..data_utils import maybe_apply_chat_template +from ..import_utils import is_mergekit_available +from ..mergekit_utils import MergeConfig, merge_models, upload_model_to_hf +from ..models.utils import unwrap_model_for_generation +from .judges import BasePairwiseJudge +from .utils import log_table_to_comet_experiment + + +if is_rich_available(): + from rich.console import Console, Group + from rich.live import Live + from rich.panel import Panel + from rich.progress import Progress + +if is_wandb_available(): + import wandb + + +def _generate_completions( + prompts: list[str], + model: PreTrainedModel, + tokenizer: PreTrainedTokenizerBase, + accelerator: Accelerator, + generation_config: Optional[GenerationConfig], + batch_size: int = 1, +) -> list[str]: + """ + Generates completions for a list of pre-formatted prompts from the given model. + + Args: + prompts (list[str]): A list of input prompts for which completions are to be generated. + model (PreTrainedModel): The pre-trained model to be used for generation. + tokenizer (PreTrainedTokenizerBase): The tokenizer to be used for encoding and decoding. + accelerator (Accelerator): The accelerator to be used for model execution. + generation_config (GenerationConfig): Configuration for text generation. + batch_size (int, optional): The number of prompts to process in each batch. Default is 1. + + Returns: + list[str]: A list of generated text completions corresponding to the input prompts. + """ + completions = [] + with unwrap_model_for_generation(model, accelerator) as unwrapped_model: + for idx in range(0, len(prompts), batch_size): + batch = prompts[idx : idx + batch_size] + tokenized_batch = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(model.device) + generations = unwrapped_model.generate( + **tokenized_batch, + generation_config=generation_config, + ) + for prompt, generation in zip(tokenized_batch.input_ids, generations): + # Remove prompt from generation + generation = generation[len(prompt) :] + completion = tokenizer.decode(generation, skip_special_tokens=True) + completions.append(completion) + return completions + + +class SyncRefModelCallback(TrainerCallback): + """ + Callback to synchronize the model with a reference model. + """ + + def __init__( + self, + ref_model: Union[PreTrainedModel, torch.nn.Module], + accelerator: Optional[Accelerator], + ): + self.accelerator = accelerator + self.ref_model = ref_model + + @staticmethod + def _sync_target_model(model, target_model, alpha): + for target_param, copy_param in zip(target_model.parameters(), model.parameters()): + target_param.data.mul_(1.0 - alpha).add_(copy_param.data, alpha=alpha) + + @staticmethod + def sync_target_model(model, target_model, alpha): + deepspeed_plugin = AcceleratorState().deepspeed_plugin + if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3: + import deepspeed + + with deepspeed.zero.GatheredParameters( + list(model.parameters()) + list(target_model.parameters()), modifier_rank=0 + ): + if deepspeed.comm.get_rank() == 0: + SyncRefModelCallback._sync_target_model(model, target_model, alpha) + else: + SyncRefModelCallback._sync_target_model(model, target_model, alpha) + + def on_step_end(self, args, state, control, **kwargs): + model: PreTrainedModel = kwargs["model"] + + if self.ref_model is not None and state.global_step % args.ref_model_sync_steps == 0: + if self.accelerator: + model = self.accelerator.unwrap_model(model) + self.sync_target_model(model, self.ref_model, args.ref_model_mixup_alpha) + + +class RichProgressCallback(TrainerCallback): + """ + A [`TrainerCallback`] that displays the progress of training or evaluation using Rich. + """ + + def __init__(self): + if not is_rich_available(): + raise ImportError("RichProgressCallback requires the `rich` extra. To install, run `pip install rich`.") + + self.training_bar = None + self.prediction_bar = None + + self.training_task_id = None + self.prediction_task_id = None + + self.rich_group = None + self.rich_console = None + + self.training_status = None + self.current_step = None + + def on_train_begin(self, args, state, control, **kwargs): + if state.is_world_process_zero: + self.training_bar = Progress() + self.prediction_bar = Progress() + + self.rich_console = Console() + + self.training_status = self.rich_console.status("Nothing to log yet ...") + + self.rich_group = Live(Panel(Group(self.training_bar, self.prediction_bar, self.training_status))) + self.rich_group.start() + + self.training_task_id = self.training_bar.add_task("[blue]Training the model", total=state.max_steps) + self.current_step = 0 + + def on_step_end(self, args, state, control, **kwargs): + if state.is_world_process_zero: + self.training_bar.update(self.training_task_id, advance=state.global_step - self.current_step, update=True) + self.current_step = state.global_step + + def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs): + if state.is_world_process_zero and has_length(eval_dataloader): + if self.prediction_task_id is None: + self.prediction_task_id = self.prediction_bar.add_task( + "[blue]Predicting on the evaluation dataset", total=len(eval_dataloader) + ) + self.prediction_bar.update(self.prediction_task_id, advance=1, update=True) + + def on_evaluate(self, args, state, control, **kwargs): + if state.is_world_process_zero: + if self.prediction_task_id is not None: + self.prediction_bar.remove_task(self.prediction_task_id) + self.prediction_task_id = None + + def on_predict(self, args, state, control, **kwargs): + if state.is_world_process_zero: + if self.prediction_task_id is not None: + self.prediction_bar.remove_task(self.prediction_task_id) + self.prediction_task_id = None + + def on_log(self, args, state, control, logs=None, **kwargs): + if state.is_world_process_zero and self.training_bar is not None: + _ = logs.pop("total_flos", None) + self.training_status.update(f"[bold green]Status = {str(logs)}") + + def on_train_end(self, args, state, control, **kwargs): + if state.is_world_process_zero: + self.rich_group.stop() + + self.training_bar = None + self.prediction_bar = None + self.training_task_id = None + self.prediction_task_id = None + self.rich_group = None + self.rich_console = None + self.training_status = None + self.current_step = None + + +def _win_rate_completions_df( + state: TrainerState, prompts: list[str], completions: list[str], winner_indices: list[str] +) -> pd.DataFrame: + global_step = [str(state.global_step)] * len(prompts) + data = list(zip(global_step, prompts, completions, winner_indices)) + # Split completions from reference model and policy + split_data = [(item[0], item[1], item[2][0], item[2][1], item[3]) for item in data] + return pd.DataFrame(split_data, columns=["step", "prompt", "reference_model", "policy", "winner_index"]) + + +class WinRateCallback(TrainerCallback): + """ + A [`~transformers.TrainerCallback`] that computes the win rate of a model based on a reference. + + It generates completions using prompts from the evaluation dataset and compares the trained model's outputs against + a reference. The reference is either the initial version of the model (before training) or the reference model, if + available in the trainer. During each evaluation step, a judge determines how often the trained model's completions + win against the reference using a judge. The win rate is then logged in the trainer's logs under the key + `"eval_win_rate"`. + + Usage: + ```python + trainer = DPOTrainer(...) + judge = PairRMJudge() + win_rate_callback = WinRateCallback(judge=judge, trainer=trainer) + trainer.add_callback(win_rate_callback) + ``` + + Args: + judge (`BasePairwiseJudge`): + The judge to use for comparing completions. + trainer (`Trainer`): + Trainer to which the callback will be attached. The trainer's evaluation dataset must include a `"prompt"` + column containing the prompts for generating completions. If the `Trainer` has a reference model (via the + `ref_model` attribute), it will use this reference model for generating the reference completions; + otherwise, it defaults to using the initial model. + generation_config (`GenerationConfig`, *optional*): + The generation config to use for generating completions. + num_prompts (`int` or `None`, *optional*, defaults to `None`): + The number of prompts to generate completions for. If not provided, defaults to the number of examples + in the evaluation dataset. + shuffle_order (`bool`, *optional*, defaults to `True`): + Whether to shuffle the order of the completions before judging. + use_soft_judge (`bool`, *optional*, defaults to `False`): + Whether to use a soft judge that returns a win probability between 0 and 1 for the first completion vs the + second. + """ + + def __init__( + self, + judge: BasePairwiseJudge, + trainer: Trainer, + generation_config: Optional[GenerationConfig] = None, + num_prompts: Optional[int] = None, + shuffle_order: bool = True, + use_soft_judge: bool = False, + ): + self.judge = judge + self.trainer = trainer + self.shuffle_order = shuffle_order + self.generation_config = generation_config + self.ref_completions = [] + self.use_soft_judge = use_soft_judge + + if self.trainer.eval_dataset is None: + raise ValueError("Trainer must have an evaluation dataset to use the WinRateCallback.") + else: + self.eval_dataset = self.trainer.eval_dataset + + if num_prompts is not None: + self.eval_dataset = self.eval_dataset.select(range(num_prompts)) + + def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + # When the trainer is initialized, we generate completions for the reference model. + tokenizer = kwargs["processing_class"] + tokenizer.padding_side = "left" + accelerator = self.trainer.accelerator + # Use the reference model if available, otherwise use the initial model + model = getattr(self.trainer, "ref_model", None) + # At this point, there are two cases where `ref_model` is None: + # 1. The method doesn't require a reference model. + # 2. The method uses a reference model, but `ref_model` is set to None. + # This occurs when using PEFT, where the reference model can be obtained by simply disabling the model's adapter. + # In theory, we should disable the adapter here, but since it's zero-initialized at the start of training, + # the model behaves identically with or without the adapter. + # Therefore, there's no need to explicitly disable it at this point. + if model is None: + model = self.trainer.model_wrapped + with accelerator.split_between_processes(self.eval_dataset["prompt"]) as prompts: + self.ref_completions = _generate_completions( + prompts, + model=model, + tokenizer=tokenizer, + accelerator=accelerator, + generation_config=self.generation_config, + batch_size=args.per_device_eval_batch_size, + ) + # Compute initial win rate as a reference point + completions = list(zip(self.ref_completions, self.ref_completions)) + if self.use_soft_judge: + ref_win_probs = self.judge.judge(prompts, completions, self.shuffle_order, return_scores=True) + winner_indices = [0 if score > 0.5 else 1 for score in ref_win_probs] + ref_win_probs = gather_object(ref_win_probs) + else: + winner_indices = self.judge.judge(prompts, completions, self.shuffle_order) + prompts = gather_object(prompts) + completions = gather_object(completions) + winner_indices = gather_object(winner_indices) + + # Logging + if self.trainer.accelerator.is_main_process: + win_rate = sum(winner_idx == 1 for winner_idx in winner_indices) / len(winner_indices) + if self.use_soft_judge: + avg_win_prob = 1.0 - sum(ref_win_probs) / len(ref_win_probs) + self.trainer.log({"eval_avg_win_prob": avg_win_prob, "eval_win_rate": win_rate}) + else: + self.trainer.log({"eval_win_rate": win_rate}) + + if "wandb" in args.report_to: + import wandb + + if wandb.run is not None: + df = _win_rate_completions_df( + state=state, + prompts=prompts, + completions=completions, + winner_indices=winner_indices, + ) + wandb.log({"win_rate_completions": wandb.Table(dataframe=df)}) + + if "comet_ml" in args.report_to: + df = _win_rate_completions_df( + state=state, + prompts=prompts, + completions=completions, + winner_indices=winner_indices, + ) + log_table_to_comet_experiment( + name="win_rate_completions.csv", + table=df, + ) + + def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + # At every evaluation step, we generate completions for the model and compare them with the reference + # completions that have been generated at the beginning of training. We then compute the win rate and log it to + # the trainer. + tokenizer = kwargs["processing_class"] + tokenizer.padding_side = "left" + accelerator = self.trainer.accelerator + model = self.trainer.model_wrapped + with accelerator.split_between_processes(self.eval_dataset["prompt"]) as prompts: + completions = _generate_completions( + prompts, + model=model, + tokenizer=tokenizer, + accelerator=accelerator, + generation_config=self.generation_config, + batch_size=args.per_device_eval_batch_size, + ) + + completions = list(zip(self.ref_completions, completions)) + + if self.use_soft_judge: + ref_win_probs = self.judge.judge(prompts, completions, self.shuffle_order, return_scores=True) + winner_indices = [0 if score > 0.5 else 1 for score in ref_win_probs] + ref_win_probs = gather_object(ref_win_probs) + else: + winner_indices = self.judge.judge(prompts, completions, self.shuffle_order) + prompts = gather_object(prompts) + completions = gather_object(completions) + winner_indices = gather_object(winner_indices) + + # Logging + if self.trainer.accelerator.is_main_process: + win_rate = sum(winner_idx == 1 for winner_idx in winner_indices) / len(winner_indices) + if self.use_soft_judge: + avg_win_prob = 1.0 - sum(ref_win_probs) / len(ref_win_probs) + self.trainer.log({"eval_avg_win_prob": avg_win_prob, "eval_win_rate": win_rate}) + else: + self.trainer.log({"eval_win_rate": win_rate}) + + if "wandb" in args.report_to: + import wandb + + if wandb.run is not None: + df = _win_rate_completions_df( + state=state, + prompts=prompts, + completions=completions, + winner_indices=winner_indices, + ) + wandb.log({"win_rate_completions": wandb.Table(dataframe=df)}) + + if "comet_ml" in args.report_to: + df = _win_rate_completions_df( + state=state, + prompts=prompts, + completions=completions, + winner_indices=winner_indices, + ) + log_table_to_comet_experiment( + name="win_rate_completions.csv", + table=df, + ) + + +class LogCompletionsCallback(TrainerCallback): + r""" + A [`~transformers.TrainerCallback`] that logs completions to Weights & Biases and/or Comet. + + Usage: + ```python + trainer = DPOTrainer(...) + completions_callback = LogCompletionsCallback(trainer=trainer) + trainer.add_callback(completions_callback) + ``` + + Args: + trainer (`Trainer`): + Trainer to which the callback will be attached. The trainer's evaluation dataset must include a `"prompt"` + column containing the prompts for generating completions. + generation_config (`GenerationConfig`, *optional*): + The generation config to use for generating completions. + num_prompts (`int` or `None`, *optional*): + The number of prompts to generate completions for. If not provided, defaults to the number of examples in the evaluation dataset. + freq (`int` or `None`, *optional*): + The frequency at which to log completions. If not provided, defaults to the trainer's `eval_steps`. + """ + + def __init__( + self, + trainer: Trainer, + generation_config: Optional[GenerationConfig] = None, + num_prompts: Optional[int] = None, + freq: Optional[int] = None, + ): + self.trainer = trainer + self.generation_config = generation_config + self.freq = freq + self.table = [] + self._last_logged_step = -1 + + if self.trainer.eval_dataset is None: + raise ValueError("Trainer must have an evaluation dataset to use the LogCompletionsCallback.") + else: + self.eval_dataset = self.trainer.eval_dataset + + if num_prompts is not None: + self.eval_dataset = self.eval_dataset.select(range(num_prompts)) + + def on_step_end(self, args, state, control, **kwargs): + # Only log once per step (this method may be called multiple times) + if state.global_step == self._last_logged_step: + return + + # Only log every `freq` steps (if no `freq` is provided, log every `eval_steps` steps) + freq = self.freq or state.eval_steps + if state.global_step % freq != 0: + return + + tokenizer = kwargs["processing_class"] + tokenizer.padding_side = "left" + accelerator = self.trainer.accelerator + model = self.trainer.model_wrapped + with accelerator.split_between_processes(self.eval_dataset["prompt"]) as prompts: + prompts = [maybe_apply_chat_template({"prompt": prompt}, tokenizer)["prompt"] for prompt in prompts] + completions = _generate_completions( + prompts, + model=model, + tokenizer=tokenizer, + accelerator=accelerator, + generation_config=self.generation_config, + batch_size=args.per_device_eval_batch_size, + ) + completions = gather_object(completions) + prompts = gather_object(prompts) + + # Build the data to log + if self.trainer.accelerator.is_main_process: + global_step = [str(state.global_step)] * len(prompts) + data = list(zip(global_step, prompts, completions)) + self.table.extend(data) + table = pd.DataFrame(columns=["step", "prompt", "completion"], data=self.table) + + if "wandb" in args.report_to: + wandb.log({"completions": table}) + + if "comet_ml" in args.report_to: + log_table_to_comet_experiment( + name="completions.csv", + table=table, + ) + + # Save the last logged step, so we don't log the same completions multiple times + self._last_logged_step = state.global_step + + +class MergeModelCallback(TrainerCallback): + r""" + A [`~transformers.TrainerCallback`] that merges the policy model (the model being trained) with another model based on a merge configuration. + + Args: + merge_config ([`MergeConfig`], *optional*, defaults to `None`): + Configuration used for the merging process. If not provided, the default [`MergeConfig`] is used. + merge_at_every_checkpoint (`bool`, *optional*, defaults to `False`): + Whether to merge the model at every checkpoint. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether to push the merged model to the Hub after merging. + + Example: + + ```python + !pip install mergekit + + from trl.mergekit_utils import MergeConfig + from trl import MergeModelCallback + + config = MergeConfig() + merge_callback = MergeModelCallback(config) + trainer = DPOTrainer(..., callbacks=[merge_callback]) + ``` + """ + + def __init__( + self, + merge_config: Optional["MergeConfig"] = None, + merge_at_every_checkpoint: bool = False, + push_to_hub: bool = False, + ): + if not is_mergekit_available(): + raise ImportError( + "MergeModelCallback requires the `mergekit` extra. To install, run `pip install mergekit`." + ) + self.merge_config = merge_config or MergeConfig() + self.merge_at_every_checkpoint = merge_at_every_checkpoint + self.push_to_hub = push_to_hub + + def _merge_and_maybe_push(self, output_dir, global_step, model): + checkpoint_path = os.path.join(output_dir, f"checkpoint-{global_step}") + self.merge_config.policy_model_path = checkpoint_path + if self.merge_config.target_model_path is None: + self.merge_config.target_model_path = model.config._name_or_path + merge_path = os.path.join(checkpoint_path, "merged") + + merge_models(self.merge_config.create(), merge_path) + + if self.push_to_hub: + repo_name = f"{output_dir}_checkpoint-{global_step}_merged" + upload_model_to_hf(merge_path, repo_name) + + def on_save(self, args, state, control, model=None, **kwargs): + if self.merge_at_every_checkpoint: + self._merge_and_maybe_push(args.output_dir, state.global_step, model) + + def on_train_end(self, args, state, control, model=None, **kwargs): + if not self.merge_at_every_checkpoint: + self._merge_and_maybe_push(args.output_dir, state.global_step, model) diff --git a/trl/trainer/cpo_config.py b/trl/trainer/cpo_config.py new file mode 100644 index 0000000000000000000000000000000000000000..594f4aedf9e74fee8fd0677e844d7bcc620b61c6 --- /dev/null +++ b/trl/trainer/cpo_config.py @@ -0,0 +1,189 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any, Optional + +from transformers import TrainingArguments + + +@dataclass +class CPOConfig(TrainingArguments): + r""" + Configuration class for the [`CPOTrainer`]. + + This class includes only the parameters that are specific to CPO training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may + differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want + to use the default data collator. + max_prompt_length (`int` or `None`, *optional*, defaults to `512`): + Maximum length of the prompt. This argument is required if you want to use the default data collator. + max_completion_length (`int` or `None`, *optional*, defaults to `None`): + Maximum length of the completion. This argument is required if you want to use the default data collator + and your model is an encoder-decoder. + beta (`float`, *optional*, defaults to `0.1`): + Parameter controlling the deviation from the reference model. Higher β means less deviation from the + reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in + the [paper](https://huggingface.co/papers/2310.12036). + label_smoothing (`float`, *optional*, defaults to `0.0`): + Label smoothing factor. This argument is required if you want to use the default data collator. + loss_type (`str`, *optional*, defaults to `"sigmoid"`): + Type of loss to use. Possible values are: + + - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper. + - `"hinge"`: hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. + - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper. + - `"simpo"`: SimPO loss from the [SimPO](https://huggingface.co/papers/2405.14734) paper. + + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model. + cpo_alpha (`float`, *optional*, defaults to `1.0`): + Weight of the BC regularizer in CPO training. + simpo_gamma (`float`, *optional*, defaults to `0.5`): + Target reward margin for the SimPO loss, used only when the `loss_type="simpo"`. + label_pad_token_id (`int`, *optional*, defaults to `-100`): + Label pad token id. This argument is required if you want to use the default data collator. + padding_value (`int` or `None`, *optional*, defaults to `None`): + Padding value to use. If `None`, the padding value of the tokenizer is used. + truncation_mode (`str`,*optional*, defaults to `"keep_end"`): + Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`. + This argument is required if you want to use the default data collator. + generate_during_eval (`bool`, *optional*, defaults to `False`): + If `True`, generates and logs completions from the model to W&B or Comet during evaluation. + is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`): + When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument, + you need to specify if the model returned by the callable is an encoder-decoder model. + model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): + Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a + string. + dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): + Number of processes to use for processing the dataset. + """ + + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=1e-6, + metadata={"help": "The initial learning rate for AdamW."}, + ) + logging_steps: float = field( + default=10, + metadata={ + "help": ( + "Log every X updates steps. Should be an integer or a float in range `[0,1)`. " + "If smaller than 1, will be interpreted as ratio of total training steps." + ) + }, + ) + bf16: bool = field( + default=True, + metadata={ + "help": ( + "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change." + ) + }, + ) + + max_length: Optional[int] = field( + default=1024, + metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."}, + ) + max_prompt_length: Optional[int] = field( + default=512, + metadata={ + "help": "Maximum length of the prompt. This argument is required if you want to use the default data " + "collator and your model is an encoder-decoder." + }, + ) + max_completion_length: Optional[int] = field( + default=None, + metadata={ + "help": "Maximum length of the completion. This argument is required if you want to use the default data " + "collator and your model is an encoder-decoder." + }, + ) + beta: float = field( + default=0.1, + metadata={ + "help": "Parameter controlling the deviation from the reference model. Higher β means less deviation from " + "the reference model." + }, + ) + label_smoothing: float = field( + default=0.0, + metadata={"help": "Label smoothing factor."}, + ) + loss_type: str = field( + default="sigmoid", + metadata={ + "help": "Type of loss to use.", + "choices": ["sigmoid", "hinge", "ipo", "simpo"], + }, + ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropout in the model."}, + ) + cpo_alpha: float = field( + default=1.0, + metadata={"help": "Weight of the BC regularizer in CPO training."}, + ) + simpo_gamma: float = field( + default=0.5, + metadata={"help": "Target reward margin for the SimPO loss, used only when the `loss_type='simpo'`."}, + ) + label_pad_token_id: int = field( + default=-100, + metadata={"help": "Label pad token id."}, + ) + padding_value: Optional[int] = field( + default=None, + metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."}, + ) + truncation_mode: str = field( + default="keep_end", + metadata={ + "help": "Truncation mode to use when the prompt is too long.", + "choices": ["keep_end", "keep_start"], + }, + ) + generate_during_eval: bool = field( + default=False, + metadata={"help": "If `True`, generates and logs completions from the model to W&B during evaluation."}, + ) + is_encoder_decoder: Optional[bool] = field( + default=None, + metadata={"help": "Whether the model is an encoder-decoder model."}, + ) + model_init_kwargs: Optional[dict[str, Any]] = field( + default=None, + metadata={ + "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model " + "from a string." + }, + ) + dataset_num_proc: Optional[int] = field( + default=None, + metadata={"help": "Number of processes to use for processing the dataset."}, + ) diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..a3250117d42d8d45abb056ce97038264aac3aec9 --- /dev/null +++ b/trl/trainer/cpo_trainer.py @@ -0,0 +1,1097 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import os +import random +import textwrap +import warnings +from collections import defaultdict +from contextlib import nullcontext +from pathlib import Path +from typing import Any, Callable, Literal, Optional, Union + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +from accelerate import PartialState +from datasets import Dataset +from torch import autocast +from torch.utils.data import DataLoader +from transformers import ( + AutoModelForCausalLM, + BaseImageProcessor, + DataCollator, + FeatureExtractionMixin, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + Trainer, + is_comet_available, + is_wandb_available, +) +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalLoopOutput +from transformers.utils import is_peft_available, is_torch_fx_proxy + +from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt +from .cpo_config import CPOConfig +from .utils import ( + DPODataCollatorWithPadding, + add_bos_token_if_needed, + add_eos_token_if_needed, + disable_dropout_in_model, + generate_model_card, + get_comet_experiment_url, + log_table_to_comet_experiment, + pad_to_length, + peft_module_casting_to_bf16, + selective_log_softmax, +) + + +if is_peft_available(): + from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + + +if is_wandb_available(): + import wandb + + +class CPOTrainer(Trainer): + r""" + Initialize CPOTrainer. + + Args: + model (`transformers.PreTrainedModel`): + The model to train, preferably an `AutoModelForSequenceClassification`. + args (`CPOConfig`): + The CPO config arguments to use for training. + data_collator (`transformers.DataCollator`): + The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used + which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences. + train_dataset (`datasets.Dataset`): + The dataset to use for training. + eval_dataset (`datasets.Dataset`): + The dataset to use for evaluation. + processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*): + Processing class used to process the data. If provided, will be used to automatically process the inputs + for the model, and it will be saved along the model to make it easier to rerun an interrupted training or + reuse the fine-tuned model. + model_init (`Callable[[], transformers.PreTrainedModel]`): + The model initializer to use for training. If None is specified, the default model initializer will be used. + callbacks (`list[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + peft_config (`dict`, defaults to `None`): + The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function to use to compute the metrics. Must take a `EvalPrediction` and return + a dictionary string to metric values. + """ + + _tag_names = ["trl", "cpo"] + + def __init__( + self, + model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + args: Optional[CPOConfig] = None, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, + processing_class: Optional[ + Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] + ] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + callbacks: Optional[list[TrainerCallback]] = None, + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + peft_config: Optional[dict] = None, + compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None, + ): + if args.model_init_kwargs is None: + model_init_kwargs = {} + elif not isinstance(model, str): + raise ValueError("You passed model_kwargs to the CPOTrainer. But your model is already instantiated.") + else: + model_init_kwargs = args.model_init_kwargs + torch_dtype = model_init_kwargs.get("torch_dtype") + if torch_dtype is not None: + # Convert to `torch.dtype` if an str is passed + if isinstance(torch_dtype, str) and torch_dtype != "auto": + torch_dtype = getattr(torch, torch_dtype) + if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"Invalid `torch_dtype` passed to the CPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}." + ) + model_init_kwargs["torch_dtype"] = torch_dtype + + if isinstance(model, str): + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + + # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` + # has been called in order to properly call autocast if needed. + self._peft_has_been_casted_to_bf16 = False + + if not is_peft_available() and peft_config is not None: + raise ValueError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + # if model is a peft model and we have a peft_config, we merge and unload it first + if isinstance(model, PeftModel): + model = model.merge_and_unload() + + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): + _support_gc_kwargs = hasattr( + args, "gradient_checkpointing_kwargs" + ) and "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + + prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} + + if _support_gc_kwargs: + prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) + elif args.gradient_checkpointing: + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + # get peft model with the given config + model = get_peft_model(model, peft_config) + if args.bf16 and getattr(model, "is_loaded_in_4bit", False): + peft_module_casting_to_bf16(model) + # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager + self._peft_has_been_casted_to_bf16 = True + + # For models that use gradient_checkpointing, we need to attach a hook that enables input + # to explicitly have `requires_grad=True`, otherwise training will either silently + # fail or completely fail. + elif args.gradient_checkpointing: + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if args.generate_during_eval and not (is_wandb_available() or is_comet_available()): + raise ValueError( + "`generate_during_eval=True` requires Weights and Biases or Comet to be installed." + " Please install `wandb` or `comet-ml` to resolve." + ) + + if model is not None: + self.is_encoder_decoder = model.config.is_encoder_decoder + elif args.is_encoder_decoder is None: + raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.") + else: + self.is_encoder_decoder = args.is_encoder_decoder + + if self.is_encoder_decoder: + self.decoder_start_token_id = model.config.decoder_start_token_id + self.pad_token_id = model.config.pad_token_id + + if processing_class is None: + raise ValueError("processing_class must be specified to tokenize a CPO dataset.") + if args.max_length is None: + warnings.warn( + "`max_length` is not set in the CPOConfig's init" + " it will default to `512` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_length = 512 + else: + max_length = args.max_length + if args.max_prompt_length is None: + warnings.warn( + "`max_prompt_length` is not set in the CPOConfig's init" + " it will default to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_prompt_length = 128 + else: + max_prompt_length = args.max_prompt_length + + if not max_prompt_length < max_length: + raise ValueError( + f"max_prompt_length ({max_prompt_length}) should be strictly less than max_length ({max_length})." + ) + + if args.max_completion_length is None and self.is_encoder_decoder: + warnings.warn( + "When using an encoder decoder architecture, you should set `max_completion_length` in the CPOConfig's init" + " it will default to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_completion_length = 128 + else: + max_completion_length = args.max_completion_length + + if data_collator is None: + data_collator = DPODataCollatorWithPadding( + pad_token_id=processing_class.pad_token_id, + label_pad_token_id=args.label_pad_token_id, + is_encoder_decoder=self.is_encoder_decoder, + ) + + if args.remove_unused_columns: + args.remove_unused_columns = False + # warn users + warnings.warn( + "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments" + " we have set it for you, but you should do it yourself in the future.", + UserWarning, + ) + + self.use_dpo_data_collator = True + else: + self.use_dpo_data_collator = False + + # Disable dropout in the model + if args.disable_dropout: + disable_dropout_in_model(model) + + self.max_length = max_length + self.generate_during_eval = args.generate_during_eval + self.label_pad_token_id = args.label_pad_token_id + self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id + self.max_prompt_length = max_prompt_length + self.truncation_mode = args.truncation_mode + self.max_completion_length = max_completion_length + self.processing_class = processing_class + + if args.loss_type in ["hinge", "ipo"] and args.label_smoothing > 0: + warnings.warn( + f"You are using the {args.loss_type} loss type that does not support label smoothing. The " + "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning.", + UserWarning, + ) + if args.loss_type == "kto_pair": + raise ValueError("Support for kto_pair has been removed in CPOTrainer. Please use KTOTrainer.") + + self.beta = args.beta + self.label_smoothing = args.label_smoothing + self.loss_type = args.loss_type + self.cpo_alpha = args.cpo_alpha + self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) + self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0) + if self.aux_loss_enabled and self.aux_loss_coef == 0.0: + warnings.warn( + "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to " + "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value " + "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary " + "loss.", + UserWarning, + ) + + if args.loss_type == "simpo": + self.simpo_gamma = args.simpo_gamma + + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the + # input tensor associated with the key "input_ids". However, in CPO, the sampled data does not include the + # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and + # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens + # of the input, floating-point operations will not be computed." To suppress this warning, we set the + # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate + # that the warning has already been issued. + model.warnings_issued["estimate_tokens"] = True + + # Compute that only on the main process for faster data processing. + # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().main_process_first(): + # Extract the prompt if needed, and apply the chat template if needed + train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc) + train_dataset = train_dataset.map( + maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc + ) + if eval_dataset is not None: + eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc) + eval_dataset = eval_dataset.map( + maybe_apply_chat_template, + fn_kwargs={"tokenizer": processing_class}, + num_proc=args.dataset_num_proc, + ) + + # tokenize the dataset + train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc) + if eval_dataset is not None: + eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. + self.model_accepts_loss_kwargs = False + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + if not hasattr(self, "accelerator"): + raise AttributeError( + "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." + ) + + def build_tokenized_answer(self, prompt, answer): + """ + Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`. + It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`. + Reference: + https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 + """ + + full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False) + prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"] + + answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :] + answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :] + + # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]` + full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids]) + + # Prepare input tokens for token by token comparison + full_input_ids = np.array(full_tokenized["input_ids"]) + + if len(full_input_ids) != len(full_concat_input_ids): + raise ValueError("Prompt input ids and answer input ids should have the same length.") + + # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens + # can be merged together when tokenizing prompt+answer. This could result + # on the last token from the prompt being different when tokenized on its own + # vs when done as prompt+answer. + response_token_ids_start_idx = len(prompt_input_ids) + + # If tokenized prompt is different than both prompt+answer, then it means the + # last token has changed due to merging. + if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]: + response_token_ids_start_idx -= 1 + + prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx] + prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx] + + if len(prompt_input_ids) != len(prompt_attention_mask): + raise ValueError("Prompt input ids and attention mask should have the same length.") + + answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:] + answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:] + + return dict( + prompt_input_ids=prompt_input_ids, + prompt_attention_mask=prompt_attention_mask, + input_ids=answer_input_ids, + attention_mask=answer_attention_mask, + ) + + def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict: + """Tokenize a single row from a CPO specific dataset. + + At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation + in case the prompt + chosen or prompt + rejected responses is/are too long. First + we truncate the prompt; if we're still too long, we truncate the chosen/rejected. + + We also create the labels for the chosen/rejected responses, which are of length equal to + the sum of the length of the prompt and the chosen/rejected response, with + label_pad_token_id for the prompt tokens. + """ + batch = {} + prompt = feature["prompt"] + chosen = feature["chosen"] + rejected = feature["rejected"] + + if not self.is_encoder_decoder: + # Check issues below for more details + # 1. https://github.com/huggingface/trl/issues/907 + # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 + # 3. https://github.com/LianjiaTech/BELLE/issues/337 + + if not isinstance(prompt, str): + raise ValueError(f"prompt should be an str but got {type(prompt)}") + prompt_tokens = self.processing_class(prompt, add_special_tokens=False) + prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()} + + if not isinstance(chosen, str): + raise ValueError(f"chosen should be an str but got {type(chosen)}") + chosen_tokens = self.build_tokenized_answer(prompt, chosen) + + if not isinstance(rejected, str): + raise ValueError(f"rejected should be an str but got {type(rejected)}") + rejected_tokens = self.build_tokenized_answer(prompt, rejected) + + # Last prompt token might get merged by tokenizer and + # it should not be included for generation if that happens + prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"]) + + chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"]) + rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"]) + prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids) + + for k, v in prompt_tokens.items(): + prompt_tokens[k] = v[:prompt_len_input_ids] + + # Make sure prompts only have one different token at most an + # and length only differs by 1 at most + num_diff_tokens = sum( + [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])] + ) + num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids) + if num_diff_tokens > 1 or num_diff_len > 1: + raise ValueError( + "Chosen and rejected prompt_input_ids might only differ on the " + "last token due to tokenizer merge ops." + ) + + # add BOS token to head of prompt. Avoid adding if it's already there + prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed( + self.processing_class.bos_token_id, + prompt_len_input_ids, + prompt_tokens, + chosen_prompt_len_input_ids, + chosen_tokens, + rejected_prompt_len_input_ids, + rejected_tokens, + ) + + # add EOS token to end of answer. Avoid adding if it's already there + chosen_tokens, rejected_tokens = add_eos_token_if_needed( + self.processing_class.eos_token_id, chosen_tokens, rejected_tokens + ) + + longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"])) + + # if combined sequence is too long, truncate the prompt + for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]: + if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: + if self.truncation_mode == "keep_start": + for k in ["prompt_input_ids", "prompt_attention_mask"]: + answer_tokens[k] = answer_tokens[k][: self.max_prompt_length] + elif self.truncation_mode == "keep_end": + for k in ["prompt_input_ids", "prompt_attention_mask"]: + answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :] + else: + raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") + + # if that's still too long, truncate the response + for answer_tokens in [chosen_tokens, rejected_tokens]: + if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: + for k in ["input_ids", "attention_mask"]: + answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length] + + # Create labels + chosen_sequence_tokens = { + k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"] + } + rejected_sequence_tokens = { + k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"] + } + chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] + chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [ + self.label_pad_token_id + ] * len(chosen_tokens["prompt_input_ids"]) + rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:] + rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [ + self.label_pad_token_id + ] * len(rejected_tokens["prompt_input_ids"]) + + for k, toks in { + "chosen_": chosen_sequence_tokens, + "rejected_": rejected_sequence_tokens, + "": prompt_tokens, + }.items(): + for type_key, tokens in toks.items(): + if type_key == "token_type_ids": + continue + batch[f"{k}{type_key}"] = tokens + + else: + chosen_tokens = self.processing_class( + chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True + ) + rejected_tokens = self.processing_class( + rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True + ) + prompt_tokens = self.processing_class( + prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True + ) + + batch["chosen_labels"] = chosen_tokens["input_ids"] + batch["rejected_labels"] = rejected_tokens["input_ids"] + batch["prompt_input_ids"] = prompt_tokens["input_ids"] + batch["prompt_attention_mask"] = prompt_tokens["attention_mask"] + + if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"): + batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( + labels=torch.tensor(batch["rejected_labels"]) + ) + batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( + labels=torch.tensor(batch["chosen_labels"]) + ) + + return batch + + @staticmethod + def concatenated_inputs( + batch: dict[str, Union[list, torch.LongTensor]], + is_encoder_decoder: bool = False, + label_pad_token_id: int = -100, + padding_value: int = 0, + device: Optional[torch.device] = None, + ) -> dict[str, torch.LongTensor]: + """Concatenate the chosen and rejected inputs into a single tensor. + + Args: + batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length). + is_encoder_decoder: Whether the model is an encoder-decoder model. + label_pad_token_id: The label pad token id. + padding_value: The padding value to use for the concatenated inputs_ids. + device: The device for the concatenated inputs. + + Returns: + A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. + """ + concatenated_batch = {} + + if is_encoder_decoder: + max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1]) + else: + max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) + + for k in batch: + if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): + if "labels" in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + concatenated_key = k.replace("chosen", "concatenated") + concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) + for k in batch: + if k.startswith("rejected") and isinstance(batch[k], torch.Tensor): + if "labels" in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + concatenated_key = k.replace("rejected", "concatenated") + concatenated_batch[concatenated_key] = torch.cat( + ( + concatenated_batch[concatenated_key], + pad_to_length(batch[k], max_length, pad_value=pad_value), + ), + dim=0, + ).to(device=device) + + if is_encoder_decoder: + concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device) + concatenated_batch["concatenated_attention_mask"] = ( + batch["prompt_attention_mask"].repeat(2, 1).to(device=device) + ) + + return concatenated_batch + + def cpo_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Compute the CPO loss for a batch of policy and reference model log probabilities. + + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + + Returns: + A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). + The losses tensor contains the CPO loss for each example in the batch. + The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. + """ + logits = (policy_chosen_logps - policy_rejected_logps).to(self.accelerator.device) + + # The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5. + # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and + # calculates a conservative CPO loss. + + if self.loss_type == "simpo": + gamma_logratios = self.simpo_gamma / self.beta + logits = logits - gamma_logratios + # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0. + losses = ( + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) + elif self.loss_type == "sigmoid": + # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0. + losses = ( + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) + elif self.loss_type == "hinge": + losses = torch.relu(1 - self.beta * logits) + elif self.loss_type == "ipo": + # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper. + losses = (logits - 1 / (2 * self.beta)) ** 2 + else: + raise ValueError( + f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'simpo']" + ) + + chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach() + rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach() + + return losses, chosen_rewards, rejected_rewards + + @staticmethod + def get_batch_logps( + logits: torch.FloatTensor, + labels: torch.LongTensor, + average_log_prob: bool = False, + label_pad_token_id: int = -100, + is_encoder_decoder: bool = False, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length) + average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. + label_pad_token_id: The label pad token id. + is_encoder_decoder: Whether the model is an encoder-decoder model. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. + """ + if logits.shape[:-1] != labels.shape: + raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") + + if not is_encoder_decoder: + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + loss_mask = labels != label_pad_token_id + + # dummy token; we'll ignore the losses on these tokens later + labels[labels == label_pad_token_id] = 0 + + per_token_logps = selective_log_softmax(logits, labels) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def concatenated_forward( + self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]] + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + + We do this to avoid doing two forward passes, because it's faster for FSDP. + """ + concatenated_batch = self.concatenated_inputs( + batch, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + padding_value=self.padding_value, + device=self.accelerator.device, + ) + len_chosen = batch["chosen_labels"].shape[0] + + model_kwargs = ( + { + "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]), + } + if self.is_encoder_decoder + else {} + ) + + if self.aux_loss_enabled: + model_kwargs["output_router_logits"] = True + + outputs = model( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + use_cache=False, + **model_kwargs, + ) + all_logits = outputs.logits + + def cross_entropy_loss(logits, labels): + if not self.is_encoder_decoder: + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + logits = logits.view(-1, logits.shape[-1]) + labels = labels.view(-1) + # Enable model parallelism + labels = labels.to(logits.device) + loss = loss_fct(logits, labels) + return loss + + labels = concatenated_batch["concatenated_labels"].clone() + + if self.cpo_alpha == 0: + nll_loss = torch.tensor(0.0).to(self.accelerator.device) + else: + nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen]) + + all_logps = self.get_batch_logps( + all_logits, + concatenated_batch["concatenated_labels"], + average_log_prob=self.loss_type in ["ipo", "simpo"], + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) + + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] + + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + if self.aux_loss_enabled: + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss, outputs.aux_loss) + + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss) + + def get_batch_loss_metrics( + self, + model, + batch: dict[str, Union[list, torch.LongTensor]], + train_eval: Literal["train", "eval"] = "train", + ): + """Compute the CPO loss and other metrics for the given batch of inputs for train or test.""" + metrics = {} + + forward_output = self.concatenated_forward(model, batch) + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + policy_nll_loss, + ) = forward_output[:5] + if self.aux_loss_enabled: + aux_loss = forward_output[5] + + losses, chosen_rewards, rejected_rewards = self.cpo_loss( + policy_chosen_logps, + policy_rejected_logps, + ) + + loss = losses.mean() + self.cpo_alpha * policy_nll_loss + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + prefix = "eval_" if train_eval == "eval" else "" + metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item() + metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item() + metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item() + metrics[f"{prefix}rewards/margins"] = ( + self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item() + ) + metrics[f"{prefix}logps/rejected"] = ( + self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean().item() + ) + metrics[f"{prefix}logps/chosen"] = ( + self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean().item() + ) + metrics[f"{prefix}logits/rejected"] = ( + self.accelerator.gather_for_metrics(policy_rejected_logits.detach().mean()).mean().item() + ) + metrics[f"{prefix}logits/chosen"] = ( + self.accelerator.gather_for_metrics(policy_chosen_logits.detach().mean()).mean().item() + ) + metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean().item() + + if self.aux_loss_enabled: + loss += self.aux_loss_coef * aux_loss + + return loss, metrics + + def compute_loss( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: dict[str, Union[torch.Tensor, Any]], + return_outputs=False, + num_items_in_batch=None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]: + compute_loss_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with compute_loss_context_manager: + loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") + + # force log the metrics + self.store_metrics(metrics, train_eval="train") + + if return_outputs: + return (loss, metrics) + return loss + + def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str: + """Generate samples from the model and reference model for the given batch of inputs.""" + + # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with + # the torch amp context manager as some hidden states are silently casted to full precision. + generate_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with generate_context_manager: + policy_output = model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.processing_class.pad_token_id, + ) + + policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id) + policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) + + return policy_output_decoded + + def prediction_step( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[list[str]] = None, + ): + if ignore_keys is None: + if hasattr(model, "config"): + ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + prediction_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with torch.no_grad(), prediction_context_manager: + loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval") + + # force log the metrics + self.store_metrics(metrics, train_eval="eval") + + if prediction_loss_only: + return (loss.detach(), None, None) + + # logits for the chosen and rejected samples from model + logits_dict = { + "eval_logits/chosen": metrics["eval_logits/chosen"], + "eval_logits/rejected": metrics["eval_logits/rejected"], + } + logits = [v for k, v in logits_dict.items() if k not in ignore_keys] + logits = torch.tensor(logits, device=self.accelerator.device) + labels = torch.zeros(logits.shape[0], device=self.accelerator.device) + + return (loss.detach(), logits, labels) + + def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: + for key, value in metrics.items(): + self._stored_metrics[train_eval][key].append(value) + + def evaluation_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: Optional[bool] = None, + ignore_keys: Optional[list[str]] = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Overriding built-in evaluation loop to store metrics for each batch. + Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. + """ + + # Sample and save to game log if requested (for one batch to save time) + if self.generate_during_eval: + # Generate random indices within the range of the total number of samples + num_samples = len(dataloader.dataset) + random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) + + # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader + random_batch_dataset = dataloader.dataset.select(random_indices) + random_batch = self.data_collator(random_batch_dataset) + random_batch = self._prepare_inputs(random_batch) + + policy_output_decoded = self.generate_from_model(self.model, random_batch) + + table = pd.DataFrame( + columns=["Prompt", "Policy"], + data=[ + [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded) + ], + ) + if "wandb" in self.args.report_to: + wandb.log({"game_log": wandb.Table(data=table)}) + + if "comet_ml" in self.args.report_to: + log_table_to_comet_experiment( + name="game_log.csv", + table=table, + ) + + # Base evaluation + initial_output = super().evaluation_loop( + dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix + ) + + return initial_output + + def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: + """ + Log `logs` on the various objects watching training, including stored metrics. + + Args: + logs (`dict[str, float]`): + The values to log. + start_time (`float` or `None`, *optional*, defaults to `None`): + Start time of the training. + """ + # logs either has 'loss' or 'eval_loss' + train_eval = "train" if "loss" in logs else "eval" + # Add averaged stored metrics to logs + for key, metrics in self._stored_metrics[train_eval].items(): + logs[key] = torch.tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + return super().log(logs, start_time) + + def _shift_right(self, input_ids): + if self.decoder_start_token_id is None: + raise ValueError( + "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id." + ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = self.decoder_start_token_id + + if self.pad_token_id is None: + raise ValueError("model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id) + + return shifted_input_ids + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, list[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str` or `None`, *optional*, defaults to `None`): + Name of the model. + dataset_name (`str` or `None`, *optional*, defaults to `None`): + Name of the dataset used for training. + tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. + """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + tags = tags or set() + if isinstance(tags, str): + tags = {tags} + + if hasattr(self.model.config, "unsloth_version"): + tags.add("unsloth") + + tags.update(self._tag_names) + + citation = textwrap.dedent("""\ + @inproceedings{xu2024contrastive, + title = {{Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation}}, + author = {Haoran Xu and Amr Sharaf and Yunmo Chen and Weiting Tan and Lingfeng Shen and Benjamin Van Durme and Kenton Murray and Young Jin Kim}, + year = 2024, + booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024}, + publisher = {OpenReview.net}, + url = {https://openreview.net/forum?id=51iwkioZpn} + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), + trainer_name="CPO", + trainer_citation=citation, + paper_title="Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation", + paper_id="2401.08417", + ) + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/ddpo_config.py b/trl/trainer/ddpo_config.py new file mode 100644 index 0000000000000000000000000000000000000000..5852e8582c0eab0438d45b7c20afeab2f1bfbc5e --- /dev/null +++ b/trl/trainer/ddpo_config.py @@ -0,0 +1,299 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +from dataclasses import dataclass, field +from typing import Optional + +from transformers import is_bitsandbytes_available + +from ..core import flatten_dict + + +@dataclass +class DDPOConfig: + r""" + Configuration class for the [`DDPOTrainer`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`): + Name of this experiment (by default is the file name without the extension name). + run_name (`str`, *optional*, defaults to `""`): + Name of this run. + seed (`int`, *optional*, defaults to `0`): + Random seed. + log_with (`Literal["wandb", "tensorboard"]]` or `None`, *optional*, defaults to `None`): + Log with either 'wandb' or 'tensorboard', check + https://huggingface.co/docs/accelerate/usage_guides/tracking for more details. + tracker_kwargs (`Dict`, *optional*, defaults to `{}`): + Keyword arguments for the tracker (e.g. wandb_project). + accelerator_kwargs (`Dict`, *optional*, defaults to `{}`): + Keyword arguments for the accelerator. + project_kwargs (`Dict`, *optional*, defaults to `{}`): + Keyword arguments for the accelerator project config (e.g. `logging_dir`). + tracker_project_name (`str`, *optional*, defaults to `"trl"`): + Name of project to use for tracking. + logdir (`str`, *optional*, defaults to `"logs"`): + Top-level logging directory for checkpoint saving. + num_epochs (`int`, *optional*, defaults to `100`): + Number of epochs to train. + save_freq (`int`, *optional*, defaults to `1`): + Number of epochs between saving model checkpoints. + num_checkpoint_limit (`int`, *optional*, defaults to `5`): + Number of checkpoints to keep before overwriting old ones. + mixed_precision (`str`, *optional*, defaults to `"fp16"`): + Mixed precision training. + allow_tf32 (`bool`, *optional*, defaults to `True`): + Allow `tf32` on Ampere GPUs. + resume_from (`str`, *optional*, defaults to `""`): + Resume training from a checkpoint. + sample_num_steps (`int`, *optional*, defaults to `50`): + Number of sampler inference steps. + sample_eta (`float`, *optional*, defaults to `1.0`): + Eta parameter for the DDIM sampler. + sample_guidance_scale (`float`, *optional*, defaults to `5.0`): + Classifier-free guidance weight. + sample_batch_size (`int`, *optional*, defaults to `1`): + Batch size (per GPU) to use for sampling. + sample_num_batches_per_epoch (`int`, *optional*, defaults to `2`): + Number of batches to sample per epoch. + train_batch_size (`int`, *optional*, defaults to `1`): + Batch size (per GPU) to use for training. + train_use_8bit_adam (`bool`, *optional*, defaults to `False`): + Use 8bit Adam optimizer from bitsandbytes. + train_learning_rate (`float`, *optional*, defaults to `3e-4`): + Learning rate. + train_adam_beta1 (`float`, *optional*, defaults to `0.9`): + Adam beta1. + train_adam_beta2 (`float`, *optional*, defaults to `0.999`): + Adam beta2. + train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`): + Adam weight decay. + train_adam_epsilon (`float`, *optional*, defaults to `1e-8`): + Adam epsilon. + train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`): + Number of gradient accumulation steps. + train_max_grad_norm (`float`, *optional*, defaults to `1.0`): + Maximum gradient norm for gradient clipping. + train_num_inner_epochs (`int`, *optional*, defaults to `1`): + Number of inner epochs per outer epoch. + train_cfg (`bool`, *optional*, defaults to `True`): + Whether to use classifier-free guidance during training. + train_adv_clip_max (`float`, *optional*, defaults to `5.0`): + Clip advantages to the range. + train_clip_range (`float`, *optional*, defaults to `1e-4`): + PPO clip range. + train_timestep_fraction (`float`, *optional*, defaults to `1.0`): + Fraction of timesteps to train on. + per_prompt_stat_tracking (`bool`, *optional*, defaults to `False`): + Whether to track statistics for each prompt separately. + per_prompt_stat_tracking_buffer_size (`int`, *optional*, defaults to `16`): + Number of reward values to store in the buffer for each prompt. + per_prompt_stat_tracking_min_count (`int`, *optional*, defaults to `16`): + Minimum number of reward values to store in the buffer. + async_reward_computation (`bool`, *optional*, defaults to `False`): + Whether to compute rewards asynchronously. + max_workers (`int`, *optional*, defaults to `2`): + Maximum number of workers to use for async reward computation. + negative_prompts (`str`, *optional*, defaults to `""`): + Comma-separated list of prompts to use as negative examples. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether to push the final model checkpoint to the Hub. + """ + + exp_name: str = field( + default=os.path.basename(sys.argv[0])[: -len(".py")], + metadata={"help": "Name of this experiment (by default is the file name without the extension name)."}, + ) + run_name: str = field( + default="", + metadata={"help": "Name of this run."}, + ) + seed: int = field( + default=0, + metadata={"help": "Random seed."}, + ) + log_with: Optional[str] = field( + default=None, + metadata={ + "help": "Log with either 'wandb' or 'tensorboard'.", + "choices": ["wandb", "tensorboard"], + }, + ) + tracker_kwargs: dict = field( + default_factory=dict, + metadata={"help": "Keyword arguments for the tracker (e.g. wandb_project)."}, + ) + accelerator_kwargs: dict = field( + default_factory=dict, + metadata={"help": "Keyword arguments for the accelerator."}, + ) + project_kwargs: dict = field( + default_factory=dict, + metadata={"help": "Keyword arguments for the accelerator project config (e.g. `logging_dir`)."}, + ) + tracker_project_name: str = field( + default="trl", + metadata={"help": "Name of project to use for tracking."}, + ) + logdir: str = field( + default="logs", + metadata={"help": "Top-level logging directory for checkpoint saving."}, + ) + num_epochs: int = field( + default=100, + metadata={"help": "Number of epochs to train."}, + ) + save_freq: int = field( + default=1, + metadata={"help": "Number of epochs between saving model checkpoints."}, + ) + num_checkpoint_limit: int = field( + default=5, + metadata={"help": "Number of checkpoints to keep before overwriting old ones."}, + ) + mixed_precision: str = field( + default="fp16", + metadata={"help": "Mixed precision training."}, + ) + allow_tf32: bool = field( + default=True, + metadata={"help": "Allow `tf32` on Ampere GPUs."}, + ) + resume_from: str = field( + default="", + metadata={"help": "Resume training from a checkpoint."}, + ) + sample_num_steps: int = field( + default=50, + metadata={"help": "Number of sampler inference steps."}, + ) + sample_eta: float = field( + default=1.0, + metadata={"help": "Eta parameter for the DDIM sampler."}, + ) + sample_guidance_scale: float = field( + default=5.0, + metadata={"help": "Classifier-free guidance weight."}, + ) + sample_batch_size: int = field( + default=1, + metadata={"help": "Batch size (per GPU) to use for sampling."}, + ) + sample_num_batches_per_epoch: int = field( + default=2, + metadata={"help": "Number of batches to sample per epoch."}, + ) + train_batch_size: int = field( + default=1, + metadata={"help": "Batch size (per GPU) to use for training."}, + ) + train_use_8bit_adam: bool = field( + default=False, + metadata={"help": "Use 8bit Adam optimizer from bitsandbytes."}, + ) + train_learning_rate: float = field( + default=3e-4, + metadata={"help": "Learning rate."}, + ) + train_adam_beta1: float = field( + default=0.9, + metadata={"help": "Adam beta1."}, + ) + train_adam_beta2: float = field( + default=0.999, + metadata={"help": "Adam beta2."}, + ) + train_adam_weight_decay: float = field( + default=1e-4, + metadata={"help": "Adam weight decay."}, + ) + train_adam_epsilon: float = field( + default=1e-8, + metadata={"help": "Adam epsilon."}, + ) + train_gradient_accumulation_steps: int = field( + default=1, + metadata={"help": "Number of gradient accumulation steps."}, + ) + train_max_grad_norm: float = field( + default=1.0, + metadata={"help": "Maximum gradient norm for gradient clipping."}, + ) + train_num_inner_epochs: int = field( + default=1, + metadata={"help": "Number of inner epochs per outer epoch."}, + ) + train_cfg: bool = field( + default=True, + metadata={"help": "Whether to use classifier-free guidance during training."}, + ) + train_adv_clip_max: float = field( + default=5.0, + metadata={"help": "Clip advantages to the range."}, + ) + train_clip_range: float = field( + default=1e-4, + metadata={"help": "PPO clip range."}, + ) + train_timestep_fraction: float = field( + default=1.0, + metadata={"help": "Fraction of timesteps to train on."}, + ) + per_prompt_stat_tracking: bool = field( + default=False, + metadata={"help": "Whether to track statistics for each prompt separately."}, + ) + per_prompt_stat_tracking_buffer_size: int = field( + default=16, + metadata={"help": "Number of reward values to store in the buffer for each prompt."}, + ) + per_prompt_stat_tracking_min_count: int = field( + default=16, + metadata={"help": "Minimum number of reward values to store in the buffer."}, + ) + async_reward_computation: bool = field( + default=False, + metadata={"help": "Whether to compute rewards asynchronously."}, + ) + max_workers: int = field( + default=2, + metadata={"help": "Maximum number of workers to use for async reward computation."}, + ) + negative_prompts: str = field( + default="", + metadata={"help": "Comma-separated list of prompts to use as negative examples."}, + ) + push_to_hub: bool = field( + default=False, + metadata={"help": "Whether to push the final model checkpoint to the Hub."}, + ) + + def to_dict(self): + output_dict = {} + for key, value in self.__dict__.items(): + output_dict[key] = value + return flatten_dict(output_dict) + + def __post_init__(self): + if self.train_use_8bit_adam and not is_bitsandbytes_available(): + raise ImportError( + "You need to install bitsandbytes to use 8bit Adam. " + "You can install it with `pip install bitsandbytes`." + ) diff --git a/trl/trainer/ddpo_trainer.py b/trl/trainer/ddpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..d800e9058ef9c87d094f913a22ce9fd18af060df --- /dev/null +++ b/trl/trainer/ddpo_trainer.py @@ -0,0 +1,664 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import textwrap +from collections import defaultdict +from concurrent import futures +from pathlib import Path +from typing import Any, Callable, Optional, Union +from warnings import warn + +import torch +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from huggingface_hub import PyTorchModelHubMixin +from transformers import is_wandb_available + +from ..models import DDPOStableDiffusionPipeline +from .ddpo_config import DDPOConfig +from .utils import PerPromptStatTracker, generate_model_card, get_comet_experiment_url + + +if is_wandb_available(): + import wandb + + +logger = get_logger(__name__) + + +class DDPOTrainer(PyTorchModelHubMixin): + """ + The DDPOTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models. + Note, this trainer is heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch + As of now only Stable Diffusion based pipelines are supported + + Attributes: + **config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `PPOConfig` for more + details. + **reward_function** (Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]) -- Reward function to be used + **prompt_function** (Callable[[], tuple[str, Any]]) -- Function to generate prompts to guide model + **sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training. + **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images + """ + + _tag_names = ["trl", "ddpo"] + + def __init__( + self, + config: DDPOConfig, + reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor], + prompt_function: Callable[[], tuple[str, Any]], + sd_pipeline: DDPOStableDiffusionPipeline, + image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None, + ): + if image_samples_hook is None: + warn("No image_samples_hook provided; no images will be logged") + + self.prompt_fn = prompt_function + self.reward_fn = reward_function + self.config = config + self.image_samples_callback = image_samples_hook + + accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs) + + if self.config.resume_from: + self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from)) + if "checkpoint_" not in os.path.basename(self.config.resume_from): + # get the most recent checkpoint in this directory + checkpoints = list( + filter( + lambda x: "checkpoint_" in x, + os.listdir(self.config.resume_from), + ) + ) + if len(checkpoints) == 0: + raise ValueError(f"No checkpoints found in {self.config.resume_from}") + checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints]) + self.config.resume_from = os.path.join( + self.config.resume_from, + f"checkpoint_{checkpoint_numbers[-1]}", + ) + + accelerator_project_config.iteration = checkpoint_numbers[-1] + 1 + + # number of timesteps within each trajectory to train on + self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction) + + self.accelerator = Accelerator( + log_with=self.config.log_with, + mixed_precision=self.config.mixed_precision, + project_config=accelerator_project_config, + # we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the + # number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get + # the total number of optimizer steps to accumulate across. + gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps, + **self.config.accelerator_kwargs, + ) + + is_okay, message = self._config_check() + if not is_okay: + raise ValueError(message) + + is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard" + + if self.accelerator.is_main_process: + self.accelerator.init_trackers( + self.config.tracker_project_name, + config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(), + init_kwargs=self.config.tracker_kwargs, + ) + + logger.info(f"\n{config}") + + set_seed(self.config.seed, device_specific=True) + + self.sd_pipeline = sd_pipeline + + self.sd_pipeline.set_progress_bar_config( + position=1, + disable=not self.accelerator.is_local_main_process, + leave=False, + desc="Timestep", + dynamic_ncols=True, + ) + + # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + if self.accelerator.mixed_precision == "fp16": + inference_dtype = torch.float16 + elif self.accelerator.mixed_precision == "bf16": + inference_dtype = torch.bfloat16 + else: + inference_dtype = torch.float32 + + self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype) + self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype) + self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype) + + trainable_layers = self.sd_pipeline.get_trainable_layers() + + self.accelerator.register_save_state_pre_hook(self._save_model_hook) + self.accelerator.register_load_state_pre_hook(self._load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if self.config.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + self.optimizer = self._setup_optimizer( + trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers + ) + + self.neg_prompt_embed = self.sd_pipeline.text_encoder( + self.sd_pipeline.tokenizer( + [""] if self.config.negative_prompts is None else self.config.negative_prompts, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.sd_pipeline.tokenizer.model_max_length, + ).input_ids.to(self.accelerator.device) + )[0] + + if config.per_prompt_stat_tracking: + self.stat_tracker = PerPromptStatTracker( + config.per_prompt_stat_tracking_buffer_size, + config.per_prompt_stat_tracking_min_count, + ) + + # NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses + # more memory + self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast + + if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora: + unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer) + self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters())) + else: + self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer) + + if self.config.async_reward_computation: + self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers) + + if config.resume_from: + logger.info(f"Resuming from {config.resume_from}") + self.accelerator.load_state(config.resume_from) + self.first_epoch = int(config.resume_from.split("_")[-1]) + 1 + else: + self.first_epoch = 0 + + def compute_rewards(self, prompt_image_pairs, is_async=False): + if not is_async: + rewards = [] + for images, prompts, prompt_metadata in prompt_image_pairs: + reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata) + rewards.append( + ( + torch.as_tensor(reward, device=self.accelerator.device), + reward_metadata, + ) + ) + else: + rewards = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs) + rewards = [ + (torch.as_tensor(reward.result(), device=self.accelerator.device), reward_metadata.result()) + for reward, reward_metadata in rewards + ] + + return zip(*rewards) + + def step(self, epoch: int, global_step: int): + """ + Perform a single step of training. + + Args: + epoch (int): The current epoch. + global_step (int): The current global step. + + Side Effects: + - Model weights are updated + - Logs the statistics to the accelerator trackers. + - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker. + + Returns: + global_step (int): The updated global step. + + """ + samples, prompt_image_data = self._generate_samples( + iterations=self.config.sample_num_batches_per_epoch, + batch_size=self.config.sample_batch_size, + ) + + # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...) + samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()} + rewards, rewards_metadata = self.compute_rewards( + prompt_image_data, is_async=self.config.async_reward_computation + ) + + for i, image_data in enumerate(prompt_image_data): + image_data.extend([rewards[i], rewards_metadata[i]]) + + if self.image_samples_callback is not None: + self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0]) + + rewards = torch.cat(rewards) + rewards = self.accelerator.gather(rewards).cpu().numpy() + + self.accelerator.log( + { + "reward": rewards, + "epoch": epoch, + "reward_mean": rewards.mean(), + "reward_std": rewards.std(), + }, + step=global_step, + ) + + if self.config.per_prompt_stat_tracking: + # gather the prompts across processes + prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy() + prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True) + advantages = self.stat_tracker.update(prompts, rewards) + else: + advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8) + + # ungather advantages; keep the entries corresponding to the samples on this process + samples["advantages"] = ( + torch.as_tensor(advantages) + .reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index] + .to(self.accelerator.device) + ) + + del samples["prompt_ids"] + + total_batch_size, num_timesteps = samples["timesteps"].shape + + for inner_epoch in range(self.config.train_num_inner_epochs): + # shuffle samples along batch dimension + perm = torch.randperm(total_batch_size, device=self.accelerator.device) + samples = {k: v[perm] for k, v in samples.items()} + + # shuffle along time dimension independently for each sample + # still trying to understand the code below + perms = torch.stack( + [torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)] + ) + + for key in ["timesteps", "latents", "next_latents", "log_probs"]: + samples[key] = samples[key][ + torch.arange(total_batch_size, device=self.accelerator.device)[:, None], + perms, + ] + + original_keys = samples.keys() + original_values = samples.values() + # rebatch them as user defined train_batch_size is different from sample_batch_size + reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values] + + # Transpose the list of original values + transposed_values = zip(*reshaped_values) + # Create new dictionaries for each row of transposed values + samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values] + + self.sd_pipeline.unet.train() + global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched) + # ensure optimization step at the end of the inner epoch + if not self.accelerator.sync_gradients: + raise ValueError( + "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings." + ) + + if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process: + self.accelerator.save_state() + + return global_step + + def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds): + """ + Calculate the loss for a batch of an unpacked sample + + Args: + latents (torch.Tensor): + The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width] + timesteps (torch.Tensor): + The timesteps sampled from the diffusion model, shape: [batch_size] + next_latents (torch.Tensor): + The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width] + log_probs (torch.Tensor): + The log probabilities of the latents, shape: [batch_size] + advantages (torch.Tensor): + The advantages of the latents, shape: [batch_size] + embeds (torch.Tensor): + The embeddings of the prompts, shape: [2*batch_size or batch_size, ...] + Note: the "or" is because if train_cfg is True, the expectation is that negative prompts are concatenated to the embeds + + Returns: + loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor) + (all of these are of shape (1,)) + """ + with self.autocast(): + if self.config.train_cfg: + noise_pred = self.sd_pipeline.unet( + torch.cat([latents] * 2), + torch.cat([timesteps] * 2), + embeds, + ).sample + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + else: + noise_pred = self.sd_pipeline.unet( + latents, + timesteps, + embeds, + ).sample + # compute the log prob of next_latents given latents under the current model + + scheduler_step_output = self.sd_pipeline.scheduler_step( + noise_pred, + timesteps, + latents, + eta=self.config.sample_eta, + prev_sample=next_latents, + ) + + log_prob = scheduler_step_output.log_probs + + advantages = torch.clamp( + advantages, + -self.config.train_adv_clip_max, + self.config.train_adv_clip_max, + ) + + ratio = torch.exp(log_prob - log_probs) + + loss = self.loss(advantages, self.config.train_clip_range, ratio) + + approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2) + + clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float()) + + return loss, approx_kl, clipfrac + + def loss( + self, + advantages: torch.Tensor, + clip_range: float, + ratio: torch.Tensor, + ): + unclipped_loss = -advantages * ratio + clipped_loss = -advantages * torch.clamp( + ratio, + 1.0 - clip_range, + 1.0 + clip_range, + ) + return torch.mean(torch.maximum(unclipped_loss, clipped_loss)) + + def _setup_optimizer(self, trainable_layers_parameters): + if self.config.train_use_8bit_adam: + import bitsandbytes + + optimizer_cls = bitsandbytes.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + return optimizer_cls( + trainable_layers_parameters, + lr=self.config.train_learning_rate, + betas=(self.config.train_adam_beta1, self.config.train_adam_beta2), + weight_decay=self.config.train_adam_weight_decay, + eps=self.config.train_adam_epsilon, + ) + + def _save_model_hook(self, models, weights, output_dir): + self.sd_pipeline.save_checkpoint(models, weights, output_dir) + weights.pop() # ensures that accelerate doesn't try to handle saving of the model + + def _load_model_hook(self, models, input_dir): + self.sd_pipeline.load_checkpoint(models, input_dir) + models.pop() # ensures that accelerate doesn't try to handle loading of the model + + def _generate_samples(self, iterations, batch_size): + """ + Generate samples from the model + + Args: + iterations (int): Number of iterations to generate samples for + batch_size (int): Batch size to use for sampling + + Returns: + samples (list[dict[str, torch.Tensor]]), prompt_image_pairs (list[list[Any]]) + """ + samples = [] + prompt_image_pairs = [] + self.sd_pipeline.unet.eval() + + sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1) + + for _ in range(iterations): + prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)]) + + prompt_ids = self.sd_pipeline.tokenizer( + prompts, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.sd_pipeline.tokenizer.model_max_length, + ).input_ids.to(self.accelerator.device) + prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0] + + with self.autocast(): + sd_output = self.sd_pipeline( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=sample_neg_prompt_embeds, + num_inference_steps=self.config.sample_num_steps, + guidance_scale=self.config.sample_guidance_scale, + eta=self.config.sample_eta, + output_type="pt", + ) + + images = sd_output.images + latents = sd_output.latents + log_probs = sd_output.log_probs + + latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, ...) + log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1) + timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1) # (batch_size, num_steps) + + samples.append( + { + "prompt_ids": prompt_ids, + "prompt_embeds": prompt_embeds, + "timesteps": timesteps, + "latents": latents[:, :-1], # each entry is the latent before timestep t + "next_latents": latents[:, 1:], # each entry is the latent after timestep t + "log_probs": log_probs, + "negative_prompt_embeds": sample_neg_prompt_embeds, + } + ) + prompt_image_pairs.append([images, prompts, prompt_metadata]) + + return samples, prompt_image_pairs + + def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples): + """ + Train on a batch of samples. Main training segment + + Args: + inner_epoch (int): The current inner epoch + epoch (int): The current epoch + global_step (int): The current global step + batched_samples (list[dict[str, torch.Tensor]]): The batched samples to train on + + Side Effects: + - Model weights are updated + - Logs the statistics to the accelerator trackers. + + Returns: + global_step (int): The updated global step + """ + info = defaultdict(list) + for _i, sample in enumerate(batched_samples): + if self.config.train_cfg: + # concat negative prompts to sample prompts to avoid two forward passes + embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]]) + else: + embeds = sample["prompt_embeds"] + + for j in range(self.num_train_timesteps): + with self.accelerator.accumulate(self.sd_pipeline.unet): + loss, approx_kl, clipfrac = self.calculate_loss( + sample["latents"][:, j], + sample["timesteps"][:, j], + sample["next_latents"][:, j], + sample["log_probs"][:, j], + sample["advantages"], + embeds, + ) + info["approx_kl"].append(approx_kl) + info["clipfrac"].append(clipfrac) + info["loss"].append(loss) + + self.accelerator.backward(loss) + if self.accelerator.sync_gradients: + self.accelerator.clip_grad_norm_( + self.trainable_layers.parameters() + if not isinstance(self.trainable_layers, list) + else self.trainable_layers, + self.config.train_max_grad_norm, + ) + self.optimizer.step() + self.optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if self.accelerator.sync_gradients: + # log training-related stuff + info = {k: torch.mean(torch.stack(v)) for k, v in info.items()} + info = self.accelerator.reduce(info, reduction="mean") + info.update({"epoch": epoch, "inner_epoch": inner_epoch}) + self.accelerator.log(info, step=global_step) + global_step += 1 + info = defaultdict(list) + return global_step + + def _config_check(self) -> tuple[bool, str]: + samples_per_epoch = ( + self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch + ) + total_train_batch_size = ( + self.config.train_batch_size + * self.accelerator.num_processes + * self.config.train_gradient_accumulation_steps + ) + + if not self.config.sample_batch_size >= self.config.train_batch_size: + return ( + False, + f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})", + ) + if not self.config.sample_batch_size % self.config.train_batch_size == 0: + return ( + False, + f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})", + ) + if not samples_per_epoch % total_train_batch_size == 0: + return ( + False, + f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})", + ) + return True, "" + + def train(self, epochs: Optional[int] = None): + """ + Train the model for a given number of epochs + """ + global_step = 0 + if epochs is None: + epochs = self.config.num_epochs + for epoch in range(self.first_epoch, epochs): + global_step = self.step(epoch, global_step) + + def _save_pretrained(self, save_directory): + self.sd_pipeline.save_pretrained(save_directory) + self.create_model_card() + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, list[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str` or `None`, *optional*, defaults to `None`): + Name of the model. + dataset_name (`str` or `None`, *optional*, defaults to `None`): + Name of the dataset used for training. + tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. + """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + tags = tags or set() + if isinstance(tags, str): + tags = {tags} + + if hasattr(self.model.config, "unsloth_version"): + tags.add("unsloth") + + tags.update(self._tag_names) + + citation = textwrap.dedent("""\ + @inproceedings{black2024training, + title = {{Training Diffusion Models with Reinforcement Learning}}, + author = {Kevin Black and Michael Janner and Yilun Du and Ilya Kostrikov and Sergey Levine}, + year = 2024, + booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024}, + publisher = {OpenReview.net}, + url = {https://openreview.net/forum?id=YCWjhGrJFD}, + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), + trainer_name="DDPO", + trainer_citation=citation, + paper_title="Training Diffusion Models with Reinforcement Learning", + paper_id="2305.13301", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/dpo_config.py b/trl/trainer/dpo_config.py new file mode 100644 index 0000000000000000000000000000000000000000..6ff7dc2e6384c3f07cc5fb68ee38f2f924c3a11a --- /dev/null +++ b/trl/trainer/dpo_config.py @@ -0,0 +1,436 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable, Optional, Union + +from transformers import TrainingArguments + + +class FDivergenceType(Enum): + REVERSE_KL = "reverse_kl" + JS_DIVERGENCE = "js_divergence" + ALPHA_DIVERGENCE = "alpha_divergence" + + +class FDivergenceConstants: + ALPHA_DIVERGENCE_COEF_KEY = "alpha_divergence_coef" + ALPHA_DIVERGENCE_COEF_DEFAULT = 1.0 + + +@dataclass +class DPOConfig(TrainingArguments): + r""" + Configuration class for the [`DPOTrainer`]. + + This class includes only the parameters that are specific to DPO training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may + differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + > Parameters that control the model and reference model + + model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): + Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of the + [`DPOTrainer`] is provided as a string. + ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): + Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `ref_model` argument of the + [`DPOTrainer`] is provided as a string. + model_adapter_name (`str` or `None`, *optional*, defaults to `None`): + Name of the train target PEFT adapter, when using LoRA with multiple adapters. + ref_adapter_name (`str` or `None`, *optional*, defaults to `None`): + Name of the reference PEFT adapter, when using LoRA with multiple adapters. + force_use_ref_model (`bool`, *optional*, defaults to `False`): + If you provide a PEFT model as the active model and wish to use a different model for the `ref_model`, set + this flag to `True`. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model and reference model. + use_logits_to_keep (`bool`, *optional*, defaults to `False`): + If `True`, only a specified number of logits are computed in the forward pass. This can be useful for + saving memory and speeding up training by not computing the logits for all tokens, especially in + scenarios when working with very long prompts where labels are ignored (-100). + + > Parameters that control the data preprocessing + + dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): + Number of processes to use for processing the dataset. + padding_value (`int` or `None`, *optional*, defaults to `None`): + Padding value to use. If `None`, the padding value of the tokenizer is used. + label_pad_token_id (`int`, *optional*, defaults to `-100`): + Padding value to use for labels. + max_prompt_length (`int` or `None`, *optional*, defaults to `512`): + Maximum length of the prompt. + max_completion_length (`int` or `None`, *optional*, defaults to `None`): + Maximum length of the completion. + max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum length of the full sequence (prompt + completion). + truncation_mode (`str`, *optional*, defaults to `"keep_end"`): + Truncation mode to use when the sequence exceeds `max_length`. Possible values are `"keep_end"` and + `"keep_start"`. + padding_free (`bool`, *optional*, defaults to `False`): + Whether to perform forward passes without padding by flattening all sequences in the batch into a single + continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only + supported with the `flash_attention_2` attention implementation, which can efficiently handle the flattened + batch structure. + precompute_ref_log_probs (`bool`, *optional*, defaults to `False`): + Whether to precompute the log probabilities from the reference model. Setting this to `True` allows + training without needing the reference model during training, which can help reduce GPU memory usage. If + set to `False` (default), the reference model will be used during training to compute log probabilities + on-the-fly. + precompute_ref_batch_size (`int` or `None`, *optional*, defaults to `None`): + Batch size to use when precomputing reference model log probabilities. This can be set higher than the + training batch size to speed up preprocessing. If `None`, defaults to `per_device_train_batch_size` for + training and `per_device_eval_batch_size` for evaluation. + tools (`Optional[list[Union[dict, Callable]]]`, *optional*, defaults to `None`): + List of tools (callable functions) that will be accessible to the model. + If the template does not support function calling, this argument will have no effect. + + > Parameters that control the training + + loss_type (`str`, *optional*, defaults to `"sigmoid"`): + Type of loss to use. Possible values are: + + - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper. + - `"hinge"`: hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. + - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper. + - `"exo_pair"`: pairwise EXO loss from the [EXO](https://huggingface.co/papers/2402.00856) paper. + - `"nca_pair"`: pairwise NCA loss from the [NCA](https://huggingface.co/papers/2402.05369) paper. + - `"robust"`: unbiased estimate of the DPO loss that is robust to preference noise from the [Robust DPO](https://huggingface.co/papers/2403.00409) paper. + - `"bco_pair"`: pairwise BCO loss from the [BCO](https://huggingface.co/papers/2404.04656) paper. + - `"sppo_hard"`: SPPO loss with hard label from the [SPPO](https://huggingface.co/papers/2405.00675) paper. + - `"aot"`: AOT loss for paired datasets from the [AOT](https://huggingface.co/papers/2406.05882) paper. + - `"aot_pair"`: AOT loss for unpaired datasets from the [AOT](https://huggingface.co/papers/2406.05882) paper. + - `"discopop"`: DiscoPOP (a.k.a Log-Ratio Modulated Loss, LRML) loss from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper. + - `"apo_zero"`: APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper. + - `"apo_down"`: APO-down loss from the [APO](https://huggingface.co/papers/2408.06266) paper. + + use_liger_loss (`bool`, *optional*, defaults to `False`): + Whether to use Liger loss. + base_model_attribute_name (`str`, *optional*, defaults to `"model"`): + Name of the attribute in the model that contains the base model. This is used to get the base model from + the model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is `True`. + beta (`float`, *optional*, defaults to `0.1`): + Parameter controlling the deviation from the reference model. Higher β means less deviation from the + reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in + the [paper](https://huggingface.co/papers/2310.12036). + f_divergence_type (`str`, *optional*, defaults to `FDivergenceType.REVERSE_KL`): + Type of f-divergence regularization function to compute divergence between policy and reference model. + f_alpha_divergence_coef (`float`, *optional*, defaults to `1.0`): + α coefficient in the α-divergence u^-α regularization function for DPO loss. + reference_free (`bool`, *optional*, defaults to `False`): + Whether to ignore the provided reference model and implicitly use a reference model that assigns equal + probability to all responses. + label_smoothing (`float`, *optional*, defaults to `0.0`): + Robust DPO label smoothing parameter from the [cDPO report](https://ericmitchell.ai/cdpo.pdf) and + [Robust DPO](https://huggingface.co/papers/2403.00409) paper that should be between `0.0` and `0.5`. + use_weighting (`bool`, *optional*, defaults to `False`): + Whether to weight the loss as done in the [WPO paper](https://huggingface.co/papers/2406.11827). + rpo_alpha (`float`, *optional*, defaults to `None`): + α parameter from the [RPO paper](https://huggingface.co/papers/2404.19733) (v3), which controls the + weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the + DPO loss. The paper recommends `rpo_alpha=1.0`. + ld_alpha (`float` or `None`, *optional*, defaults to `None`): + α parameter from the [LD-DPO paper](https://huggingface.co/papers/2409.06411), which controls the weighting + of the verbose token log-probabilities in responses. If `None`, no weighting is applied to the verbose + part, and the loss is equivalent to the standard DPO loss. The paper recommends setting `ld_alpha` between + `0.0` and `1.0`. + discopop_tau (`float`, *optional*, defaults to `0.05`): + τ/temperature parameter from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper, which controls + the shape of log ratio modulated loss. The paper recommends the default value `discopop_tau=0.05`. + sync_ref_model (`bool`, *optional*, defaults to `False`): + Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using + the `ref_model_mixup_alpha` parameter. This synchronization originites from the + [TR-DPO](https://huggingface.co/papers/2404.09656) paper. + ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`): + α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix + between the current policy and the previous reference policy during updates. The reference policy is + updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you + must set `sync_ref_model=True`. + ref_model_sync_steps (`int`, *optional*, defaults to `512`): + τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how + frequently the current policy is synchronized with the reference policy. To use this parameter, you must + set `sync_ref_model=True`. + + > Parameters that control the logging + + generate_during_eval (`bool`, *optional*, defaults to `False`): + Whether to generate and log completions from both the model and the reference model to W&B or Comet during + evaluation. + """ + + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs", "ref_model_init_kwargs"] + + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=1e-6, + metadata={"help": "The initial learning rate for AdamW."}, + ) + logging_steps: float = field( + default=10, + metadata={ + "help": ( + "Log every X updates steps. Should be an integer or a float in range `[0,1)`. " + "If smaller than 1, will be interpreted as ratio of total training steps." + ) + }, + ) + bf16: bool = field( + default=True, + metadata={ + "help": ( + "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change." + ) + }, + ) + + # Parameters that control the model and reference model + model_init_kwargs: Optional[dict[str, Any]] = field( + default=None, + metadata={ + "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of " + "the `DPOTrainer` is provided as a string." + }, + ) + ref_model_init_kwargs: Optional[dict[str, Any]] = field( + default=None, + metadata={ + "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `ref_model` argument " + "of the `DPOTrainer` is provided as a string." + }, + ) + model_adapter_name: Optional[str] = field( + default=None, + metadata={"help": "Name of the train target PEFT adapter, when using LoRA with multiple adapters."}, + ) + ref_adapter_name: Optional[str] = field( + default=None, + metadata={"help": "Name of the reference PEFT adapter, when using LoRA with multiple adapters."}, + ) + force_use_ref_model: bool = field( + default=False, + metadata={ + "help": "If you provide a PEFT model as the active model and wish to use a different model for the " + "`ref_model`, set this flag to `True`." + }, + ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropout in the model and reference model."}, + ) + use_logits_to_keep: bool = field( + default=False, + metadata={ + "help": "If `True`, only a specified number of logits are computed in the forward pass. This can be " + "useful for saving memory and speeding up training by not computing the logits for all tokens, especially " + "in scenarios when working with very long prompts where labels are ignored (-100)." + }, + ) + + # Parameters that control the data preprocessing + dataset_num_proc: Optional[int] = field( + default=None, + metadata={"help": "Number of processes to use for processing the dataset."}, + ) + padding_value: Optional[int] = field( + default=None, + metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."}, + ) + label_pad_token_id: int = field( + default=-100, + metadata={"help": "Padding value to use for labels."}, + ) + max_prompt_length: Optional[int] = field( + default=512, + metadata={"help": "Maximum length of the prompt."}, + ) + max_completion_length: Optional[int] = field( + default=None, + metadata={"help": "Maximum length of the completion."}, + ) + max_length: Optional[int] = field( + default=1024, + metadata={"help": "Maximum length of the full sequence (prompt + completion)."}, + ) + truncation_mode: str = field( + default="keep_end", + metadata={ + "help": "Truncation mode to use when the sequence exceeds `max_length`. Possible values are `'keep_end'` " + "and `'keep_start'`.", + "choices": ["keep_end", "keep_start"], + }, + ) + padding_free: bool = field( + default=False, + metadata={ + "help": "Whether to perform forward passes without padding by flattening all sequences in the batch into " + "a single continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, " + "this is only supported with the `flash_attention_2` attention implementation, which can efficiently " + "handle the flattened batch structure." + }, + ) + precompute_ref_log_probs: bool = field( + default=False, + metadata={ + "help": "Whether to precompute the log probabilities from the reference model. Setting this to `True` " + "allows training without needing the reference model during training, which can help reduce GPU memory " + "usage. If set to `False` (default), the reference model will be used during training to compute log " + "probabilities on-the-fly." + }, + ) + precompute_ref_batch_size: Optional[int] = field( + default=None, + metadata={ + "help": "Batch size to use when precomputing reference model log probabilities. This can be set higher " + "than the training batch size to speed up preprocessing. If `None`, defaults to " + "`per_device_train_batch_size` for training and `per_device_eval_batch_size` for evaluation." + }, + ) + tools: Optional[list[Union[dict, Callable]]] = field( + default=None, + metadata={ + "help": "List of tools (callable functions) that will be accessible to the model. If the template does " + "not support function calling, this argument will have no effect." + }, + ) + + # Parameters that control the training + loss_type: str = field( + default="sigmoid", + metadata={ + "help": "Type of loss to use.", + "choices": [ + "sigmoid", + "hinge", + "ipo", + "exo_pair", + "nca_pair", + "robust", + "bco_pair", + "sppo_hard", + "aot", + "aot_pair", + "discopop", + "apo_zero", + "apo_down", + ], + }, + ) + use_liger_loss: bool = field( + default=False, + metadata={"help": "Whether to use Liger loss."}, + ) + base_model_attribute_name: str = field( + default="model", + metadata={ + "help": "Name of the attribute in the model that contains the base model. This is used to get the base " + "model from the model when the model does not have a `get_decoder` method in the case when " + "`use_liger_loss` is `True`." + }, + ) + beta: float = field( + default=0.1, + metadata={ + "help": "Parameter controlling the deviation from the reference model. " + "Higher β means less deviation from the reference model." + }, + ) + f_divergence_type: FDivergenceType = field( + default=FDivergenceType.REVERSE_KL, + metadata={ + "help": "Type of f-divergence regularization function to compute divergence between policy and reference " + "model." + }, + ) + f_alpha_divergence_coef: float = field( + default=1.0, + metadata={"help": "α coefficient in the α-divergence u^-α regularization function for DPO loss."}, + ) + reference_free: bool = field( + default=False, + metadata={ + "help": "Whether to ignore the provided reference model and implicitly use a reference model that assigns " + "equal probability to all responses." + }, + ) + label_smoothing: float = field( + default=0.0, + metadata={ + "help": "Robust DPO label smoothing parameter from the cDPO report and Robust DPO paper that should " + "be between `0.0` and `0.5`." + }, + ) + use_weighting: bool = field( + default=False, + metadata={"help": "Whether to weight the loss as done in the WPO paper."}, + ) + rpo_alpha: Optional[float] = field( + default=None, + metadata={ + "help": "α parameter from the RPO paper (v3), which controls the weighting of the NLL term in the loss. " + "If `None`, no weighting is applied and the loss is the same as the DPO loss. The paper recommends " + "`rpo_alpha=1.0`." + }, + ) + ld_alpha: Optional[float] = field( + default=None, + metadata={ + "help": "α parameter from the LD-DPO paper, which controls the weighting of the verbose token " + "log-probabilities in responses. If `None`, no weighting is applied to the verbose part, and the loss is " + "equivalent to the standard DPO loss. The paper recommends setting `ld_alpha` between `0.0` and `1.0`.", + }, + ) + discopop_tau: float = field( + default=0.05, + metadata={ + "help": "τ/temperature parameter from the DiscoPOP paper, which controls the shape of log ratio modulated " + "loss. The paper recommends the default value `discopop_tau=0.05`." + }, + ) + sync_ref_model: bool = field( + default=False, + metadata={ + "help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` " + "steps, using the `ref_model_mixup_alpha` parameter." + }, + ) + ref_model_mixup_alpha: float = field( + default=0.6, + metadata={ + "help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the " + "previous reference policy during updates. The reference policy is updated according to the equation: " + "`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`." + }, + ) + ref_model_sync_steps: int = field( + default=512, + metadata={ + "help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is " + "synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`." + }, + ) + + # Parameters that control the logging + generate_during_eval: bool = field( + default=False, + metadata={ + "help": "Whether to generate and log completions from both the model and the reference model to W&B or " + "Comet during evaluation." + }, + ) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..a9e1c20e064b3dcb2a45ec5567b404237bc5003f --- /dev/null +++ b/trl/trainer/dpo_trainer.py @@ -0,0 +1,1917 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import os +import random +import textwrap +import warnings +from collections import defaultdict +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Literal, Optional, Union + +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +from accelerate import PartialState +from accelerate.utils import tqdm +from datasets import Dataset, IterableDataset +from torch import autocast +from torch.utils.data import DataLoader +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + BaseImageProcessor, + DataCollator, + FeatureExtractionMixin, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + Trainer, + is_comet_available, + is_wandb_available, +) +from transformers.data.data_collator import DataCollatorMixin +from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalLoopOutput +from transformers.utils import is_liger_kernel_available, is_peft_available + +from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt +from ..models import create_reference_model, prepare_deepspeed +from ..models.utils import prepare_fsdp +from .callbacks import SyncRefModelCallback +from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType +from .utils import ( + RunningMoments, + cap_exp, + disable_dropout_in_model, + empty_cache, + flush_left, + flush_right, + generate_model_card, + get_comet_experiment_url, + log_table_to_comet_experiment, + pad, + pad_to_length, + peft_module_casting_to_bf16, + selective_log_softmax, +) + + +if is_peft_available(): + from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training + +if is_liger_kernel_available(): + from liger_kernel.chunked_loss import LigerFusedLinearDPOLoss + + +if is_wandb_available(): + import wandb + + +def shift_tokens_right(input_ids: torch.Tensor, decoder_start_token_id: int) -> torch.Tensor: + """Shift input ids one token to the right, and pad with pad_token_id""" + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + +@dataclass +class DataCollatorForPreference(DataCollatorMixin): + """ + Data collator used for preference data. Inputs are dynamically padded to the maximum length of a batch if they + are not all of the same length. + + Args: + pad_token_id (`int`): + Token ID to use for padding. + return_tensors (`str`, *optional*, defaults to `"pt"`): + Type of Tensor to return. Only `"pt"` is currently supported. + + Examples: + ```python + >>> from trl import DataCollatorForPreference + >>> collator = DataCollatorForPreference(pad_token_id=0) + >>> examples = [ + ... {"prompt_input_ids": [1, 2, 3], "chosen_input_ids": [4, 5], "rejected_input_ids": [6]}, + ... {"prompt_input_ids": [7, 8], "chosen_input_ids": [9, 10], "rejected_input_ids": [11, 12, 13]} + ... ] + >>> collator(examples) + {'prompt_input_ids': tensor([[1, 2, 3], + [0, 7, 8]]), + 'prompt_attention_mask': tensor([[1, 1, 1], + [0, 1, 1]]), + 'chosen_input_ids': tensor([[ 4, 5], + [ 9, 10]]), + 'chosen_attention_mask': tensor([[1, 1], + [1, 1]]), + 'rejected_input_ids': tensor([[ 6, 0, 0], + [11, 12, 13]]), + 'rejected_attention_mask': tensor([[1, 0, 0], + [1, 1, 1]]) + } + ``` + """ + + pad_token_id: int + return_tensors: str = "pt" + + def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: + # Convert to tensor + prompt_input_ids = [torch.tensor(example["prompt_input_ids"]) for example in examples] + prompt_attention_mask = [torch.ones_like(input_ids) for input_ids in prompt_input_ids] + chosen_input_ids = [torch.tensor(example["chosen_input_ids"]) for example in examples] + chosen_attention_mask = [torch.ones_like(input_ids) for input_ids in chosen_input_ids] + rejected_input_ids = [torch.tensor(example["rejected_input_ids"]) for example in examples] + rejected_attention_mask = [torch.ones_like(input_ids) for input_ids in rejected_input_ids] + if "pixel_values" in examples[0]: + pixel_values = [torch.tensor(example["pixel_values"]) for example in examples] + if "pixel_attention_mask" in examples[0]: + pixel_attention_mask = [torch.tensor(example["pixel_attention_mask"]) for example in examples] + if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]: + ref_chosen_logps = torch.tensor([example["ref_chosen_logps"] for example in examples]) + ref_rejected_logps = torch.tensor([example["ref_rejected_logps"] for example in examples]) + + # Pad + output = {} + output["prompt_input_ids"] = pad(prompt_input_ids, padding_value=self.pad_token_id, padding_side="left") + output["prompt_attention_mask"] = pad(prompt_attention_mask, padding_value=0, padding_side="left") + output["chosen_input_ids"] = pad(chosen_input_ids, padding_value=self.pad_token_id) + output["chosen_attention_mask"] = pad(chosen_attention_mask, padding_value=0) + output["rejected_input_ids"] = pad(rejected_input_ids, padding_value=self.pad_token_id) + output["rejected_attention_mask"] = pad(rejected_attention_mask, padding_value=0) + if "pixel_values" in examples[0]: + output["pixel_values"] = pad(pixel_values, padding_value=0.0) + if "pixel_attention_mask" in examples[0]: + output["pixel_attention_mask"] = pad(pixel_attention_mask, padding_value=0) + if "image_sizes" in examples[0]: + output["image_sizes"] = torch.tensor([example["image_sizes"] for example in examples]) + if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]: + output["ref_chosen_logps"] = ref_chosen_logps + output["ref_rejected_logps"] = ref_rejected_logps + + return output + + +class DPOTrainer(Trainer): + """ + Trainer for Direct Preference Optimization (DPO) method. + + This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods. + + Args: + model (`Union[str, PreTrainedModel]`): + Model to be trained. Can be either: + + - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or + a path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is + loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments + in `args.model_init_kwargs`. + - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. + ref_model (`PreTrainedModelWrapper`): + Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no + reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized. + args ([`DPOConfig`], *optional*, defaults to `None`): + Configuration for this trainer. If `None`, a default configuration is used. + data_collator (`DataCollator`, *optional*): + Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. + Will default to [`DataCollatorForPreference`]. + train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): + Dataset to use for training. DPO supports [preference](#preference) type and. The format of the samples can + be either: + + - [Standard](dataset_formats#standard): Each sample contains plain text. + - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role + and content). + eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): + Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. + processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`): + Processing class used to process the data. If `None`, the processing class is loaded from the model's name + with [`~transformers.AutoTokenizer.from_pretrained`]. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return + a dictionary string to metric values. *Note* When passing TrainingArgs with `batch_eval_metrics` set to + `True`, your compute_metrics function must take a boolean `compute_result` argument. This will be triggered + after the last eval batch to signal that the function needs to calculate and return the global summary + statistics rather than accumulating the batch-level statistics. + callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`): + List of callbacks to customize the training loop. Will add those to the list of default callbacks + detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback). + + If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] + method. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`): + A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your + model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. + optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*, defaults to `None`): + A tuple containing the optimizer class and keyword arguments to use. + Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*, defaults to `None`): + A function that preprocess the logits right before caching them at each evaluation step. Must take two + tensors, the logits and the labels, and return the logits once processed as desired. The modifications made + by this function will be reflected in the predictions received by `compute_metrics`. + + Note that the labels (second parameter) will be `None` if the dataset does not have them. + peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`): + PEFT configuration used to wrap the model. If `None`, the model is not wrapped. + """ + + _tag_names = ["trl", "dpo"] + + def __init__( + self, + model: Union[str, nn.Module, PreTrainedModel], + ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + args: Optional[DPOConfig] = None, + data_collator: Optional[DataCollator] = None, # type: ignore + train_dataset: Optional[Union[Dataset, IterableDataset]] = None, + eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None, + processing_class: Optional[ + Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] + ] = None, + compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None, + callbacks: Optional[list[TrainerCallback]] = None, + optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), + optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None, + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + peft_config: Optional["PeftConfig"] = None, + ): + # Args + model_id = model if isinstance(model, str) else model.config._name_or_path + if args is None: + model_name = model_id.split("/")[-1] + args = DPOConfig(f"{model_name}-DPO") + + # Handle the tokenizer + if processing_class is None: + processing_class = AutoTokenizer.from_pretrained(model_id) + + if args.padding_value is not None: + self.padding_value = args.padding_value + else: + if hasattr(processing_class, "pad_token_id") and processing_class.pad_token_id is not None: + self.padding_value = processing_class.pad_token_id + elif hasattr(processing_class, "tokenizer") and processing_class.tokenizer.pad_token_id is not None: + self.padding_value = processing_class.tokenizer.pad_token_id + else: + raise ValueError( + "`padding_value` is not specified in `DPOConfig`, and `pad_token_id` is missing in the " + "`processing_class`. Please either set the `padding_value` argument in `DPOConfig`, or set " + "`tokenizer.pad_token` (e.g., `tokenizer.pad_token = tokenizer.eos_token`) before instantiating " + "the trainer." + ) + + # Model + if not isinstance(model, str) and ref_model is model: + raise ValueError( + "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " + "same as `model`, you must mass a copy of it, or `None` if you use peft." + ) + + if args.model_init_kwargs is not None and not isinstance(model, str): + warnings.warn( + "You passed model_init_kwargs to the `DPOConfig`, but your model is already instantiated. " + "The `model_init_kwargs` will be ignored." + ) + if isinstance(model, str): + model = self._create_model_from_path(model, args) + + if args.ref_model_init_kwargs is not None and not isinstance(ref_model, str): + warnings.warn( + "You passed ref_model_init_kwargs to the `DPOConfig`, but your ref_model is already instantiated. " + "The `ref_model_init_kwargs` will be ignored." + ) + if isinstance(ref_model, str): + ref_model = self._create_model_from_path(ref_model, args, is_ref=True) + + # PEFT configuration and model wrapping + model = self._prepare_peft_model(model, ref_model, peft_config, args) + + if args.generate_during_eval and not (is_wandb_available() or is_comet_available()): + raise ValueError( + "`generate_during_eval=True` requires Weights and Biases or Comet to be installed." + " Please install `wandb` or `comet-ml` to resolve." + ) + + self.is_encoder_decoder = model.config.is_encoder_decoder + self.is_vision_model = model.config.model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES.keys() + self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) + self.model_adapter_name = args.model_adapter_name + self.ref_adapter_name = args.ref_adapter_name + self.reference_free = args.reference_free + + if ref_model: + self.ref_model = ref_model + elif self.is_peft_model or args.precompute_ref_log_probs: + # The `model` with adapters turned off will be used as the reference model + self.ref_model = None + else: + self.ref_model = create_reference_model(model) + + # Disable dropout in the model and reference model + if args.disable_dropout: + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + # Liger kernel + if args.use_liger_loss: + if not is_liger_kernel_available(): + raise ImportError( + "You set `use_liger_loss=True` but the liger kernel is not available. " + "Please install liger-kernel first: `pip install liger-kernel`" + ) + if args.loss_type != "sigmoid": + raise ValueError( + "You set `use_liger_loss=True` but the loss type is not `sigmoid`. " + "Please set `loss_type='sigmoid'` to use the liger kernel." + ) + self.dpo_loss_fn = LigerFusedLinearDPOLoss( + ignore_index=args.label_pad_token_id, + beta=args.beta, + use_ref_model=not args.reference_free, + average_log_prob=False, + ) + # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the + # input tensor associated with the key "input_ids". However, in DPO, the sampled data does not include the + # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and + # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens + # of the input, floating-point operations will not be computed." To suppress this warning, we set the + # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate + # that the warning has already been issued. + model.warnings_issued["estimate_tokens"] = True + + # Data collator + if data_collator is None: + data_collator = DataCollatorForPreference(pad_token_id=self.padding_value) + + self.generate_during_eval = args.generate_during_eval + self.label_pad_token_id = args.label_pad_token_id + self.max_prompt_length = args.max_prompt_length + self.max_completion_length = args.max_completion_length + self.max_length = args.max_length + self.truncation_mode = args.truncation_mode + self.precompute_ref_log_probs = args.precompute_ref_log_probs + self.use_logits_to_keep = args.use_logits_to_keep + + if args.padding_free: + if model.config._attn_implementation != "flash_attention_2": + warnings.warn( + "Padding-free training is enabled, but the attention implementation is not set to " + "'flash_attention_2'. Padding-free training flattens batches into a single sequence, and " + "'flash_attention_2' is the only known attention mechanism that reliably supports this. Using " + "other implementations may lead to unexpected behavior. To ensure compatibility, set " + "`attn_implementation='flash_attention_2'` in the model configuration, or verify that your " + "attention mechanism can handle flattened sequences." + ) + if args.per_device_train_batch_size == 1: + warnings.warn( + "You are using a per_device_train_batch_size of 1 with padding-free training. Using a batch size " + "of 1 anihilate the benefits of padding-free training. Please consider increasing the batch size " + "to at least 2." + ) + self.padding_free = args.padding_free + + # Since ref_logs are precomputed on the first call to get_train/eval_dataloader + # keep track of first called to avoid computation of future calls + self._precomputed_train_ref_log_probs = False + self._precomputed_eval_ref_log_probs = False + + if ( + args.loss_type in ["hinge", "ipo", "bco_pair", "sppo_hard", "nca_pair", "apo_zero", "apo_down"] + and args.label_smoothing > 0 + ): + warnings.warn( + f"You are using the {args.loss_type} loss type that does not support label smoothing. The " + "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning.", + UserWarning, + ) + if args.loss_type == "kto_pair": + raise ValueError("Support for kto_pair has been removed in DPOTrainer. Please use KTOTrainer.") + + self.beta = args.beta + self.label_smoothing = args.label_smoothing + self.loss_type = args.loss_type + self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) + self.use_weighting = args.use_weighting + self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0) + if self.aux_loss_enabled and self.aux_loss_coef == 0.0: + warnings.warn( + "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to " + "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value " + "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary " + "loss.", + UserWarning, + ) + + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + self.f_divergence_type = args.f_divergence_type + self.f_divergence_params = {FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY: args.f_alpha_divergence_coef} + self.dataset_num_proc = args.dataset_num_proc + + # Dataset preparation + train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train") + if eval_dataset is not None: + if isinstance(eval_dataset, dict): + eval_dataset = { + key: self._prepare_dataset(dataset, processing_class, args, key) + for key, dataset in eval_dataset.items() + } + else: + eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval") + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + optimizer_cls_and_kwargs=optimizer_cls_and_kwargs, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. + self.model_accepts_loss_kwargs = False + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + if not hasattr(self, "accelerator"): + raise AttributeError( + "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." + ) + + # Deepspeed Zero-3 does not support precompute_ref_log_probs + if self.is_deepspeed_enabled: + if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs: + raise ValueError( + "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`." + ) + + if self.ref_model is None: + if not (self.is_peft_model or self.precompute_ref_log_probs): + raise ValueError( + "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`" + ) + if args.sync_ref_model: + raise ValueError( + "You currently cannot use `ref_model=None` with TR-DPO method. Please provide `ref_model`." + ) + else: + if self.is_deepspeed_enabled: + self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) + elif self.is_fsdp_enabled: + self.ref_model = prepare_fsdp(self.ref_model, self.accelerator) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + if args.sync_ref_model: + if self.precompute_ref_log_probs: + raise ValueError( + "You cannot use `precompute_ref_log_probs=True` with TR-DPO method. Please set `precompute_ref_log_probs=False`." + ) + + self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator)) + + if self.loss_type == "bco_pair": + self.running = RunningMoments(self.accelerator) + + def _create_model_from_path(self, model_path: str, args: DPOConfig, is_ref: bool = False) -> PreTrainedModel: + """Creates a model from a path or model identifier.""" + if not is_ref: + model_init_kwargs = args.model_init_kwargs or {} + else: + model_init_kwargs = args.ref_model_init_kwargs or {} + + # Handle torch dtype + torch_dtype = model_init_kwargs.get("torch_dtype") + if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None: + pass # torch_dtype is already a torch.dtype or "auto" or None + elif isinstance(torch_dtype, str): # it's a str, but not "auto" + torch_dtype = getattr(torch, torch_dtype) + model_init_kwargs["torch_dtype"] = torch_dtype + else: + raise ValueError( + "Invalid `torch_dtype` passed to `DPOConfig`. Expected either 'auto' or a string representing " + f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}." + ) + # Disable caching if gradient checkpointing is enabled (not supported) + # if args.gradient_checkpointing: + # model_init_kwargs["use_cache"] = False + + # Create model + model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs) + return model + + def _prepare_peft_model( + self, model: PreTrainedModel, ref_model: PreTrainedModel, peft_config: Any, args: DPOConfig + ) -> PreTrainedModel: + """Prepares a model for PEFT training.""" + # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` + # has been called in order to properly call autocast if needed. + self._peft_has_been_casted_to_bf16 = False + + if not is_peft_available() and peft_config is not None: + raise ValueError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + # if model is a peft model and we have a peft_config, we merge and unload it first + if isinstance(model, PeftModel): + model = model.merge_and_unload() + + if ref_model is not None and not args.force_use_ref_model: + raise ValueError( + "You passed both a ref_model and a peft_config. For training PEFT adapters with DPO there is no need to pass a reference" + " model. Please pass `ref_model=None` in case you want to train PEFT adapters, or pass a ref_model with `force_use_ref_model=True` in DPOTrainer's init." + " if you want to use a different ref_model." + ) + + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): + _support_gc_kwargs = hasattr( + args, "gradient_checkpointing_kwargs" + ) and "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + + prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} + + if _support_gc_kwargs: + prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) + + else: + model = self._prepare_gradient_checkpointing(model, args) + + # get peft model with the given config + model = get_peft_model(model, peft_config) + if args.bf16 and getattr(model, "is_loaded_in_4bit", False): + peft_module_casting_to_bf16(model) + # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager + self._peft_has_been_casted_to_bf16 = True + + else: + model = self._prepare_gradient_checkpointing(model, args) + + return model + + def _prepare_gradient_checkpointing(self, model: PreTrainedModel, args: DPOConfig): + """Prepare the gradienting checkpointing for the model.""" + # For models that use gradient_checkpointing, we need to attach a hook that enables input + # to explicitly have `requires_grad=True`, otherwise training will either silently + # fail or completely fail. + if args.gradient_checkpointing: + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + return model + + def _prepare_dataset( + self, + dataset: Union[Dataset, IterableDataset], + processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin], + args: DPOConfig, + dataset_name: str, + ) -> Union[Dataset, IterableDataset]: + # Build the kwargs for the `map` function + map_kwargs = {} + if isinstance(dataset, Dataset): # IterableDataset does not support num_proc nor writer_batch_size + map_kwargs["num_proc"] = args.dataset_num_proc + map_kwargs["writer_batch_size"] = 10 + + with PartialState().main_process_first(): + # Extract prompt if needed + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset" + dataset = dataset.map(maybe_extract_prompt, **map_kwargs) + + # Apply the chat template if needed + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset" + dataset = dataset.map( + maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class, "tools": args.tools}, **map_kwargs + ) + + # Tokenize the dataset + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" + + dataset = dataset.map( + self.tokenize_row if not self.is_vision_model else self.process_row, + remove_columns=["chosen", "rejected"], + fn_kwargs={ + "processing_class": processing_class, + "max_prompt_length": args.max_prompt_length, + "max_completion_length": args.max_completion_length, + # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token]) + "add_special_tokens": False, + }, + **map_kwargs, + ) + + return dataset + + @staticmethod + def tokenize_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens): + """ + Tokenize a row of the dataset. + + Args: + features (`dict[str, str]`): + Row of the dataset, should contain the keys `"prompt"`, `"chosen"`, and `"rejected"`. + processing_class (`PreTrainedTokenizerBase`): + Processing class used to process the data. + max_prompt_length (`int` or `None`): + Maximum length of the prompt sequence. If `None`, the prompt sequence is not truncated. + max_completion_length (`int` or `None`): + Maximum length of the completion sequences. If `None`, the completion sequences are not truncated. + add_special_tokens (`bool`): + Whether to add special tokens to the sequences. Typically used for encoder-decoder models. If `True`, + the prompt sequence will have a bos token prepended and an eos token appended. In any case, the + completion sequences will have an eos token appended. + + Returns: + `dict[str, list[int]]`: + Tokenized sequences with the keys `"prompt_input_ids"`, `"chosen_input_ids"`, and + `"rejected_input_ids". + + Example: + ```python + >>> from transformers import GPT2Tokenizer + >>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + >>> features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"} + >>> DPOTrainer.tokenize_row( + ... features, tokenizer, max_prompt_length=3, max_completion_length=3, add_special_tokens=False + ... ) + {'prompt_input_ids': [464, 6766, 318], 'chosen_input_ids': [4171, 50256], 'rejected_input_ids': [4077, 50256]} + ``` + """ + tokenizer = processing_class # the processing class is a tokenizer + prompt_input_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"] + chosen_input_ids = tokenizer(features["chosen"], add_special_tokens=False)["input_ids"] + rejected_input_ids = tokenizer(features["rejected"], add_special_tokens=False)["input_ids"] + + # Add special tokens (typically for encoder-decoder models) + if add_special_tokens: + if tokenizer.bos_token_id is not None: + prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids + if tokenizer.eos_token_id is not None: + prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id] + chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id] + rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id] + + # Truncate prompt and completion sequences + if max_prompt_length is not None: + prompt_input_ids = prompt_input_ids[-max_prompt_length:] + if max_completion_length is not None: + chosen_input_ids = chosen_input_ids[:max_completion_length] + rejected_input_ids = rejected_input_ids[:max_completion_length] + + return { + "prompt_input_ids": prompt_input_ids, + "chosen_input_ids": chosen_input_ids, + "rejected_input_ids": rejected_input_ids, + } + + @staticmethod + def process_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens): + """ + Same as `tokenize_row` but for vision models. Please refer to `tokenize_row` for more information. + """ + processor, tokenizer = processing_class, processing_class.tokenizer # the processing class is a processor + processed_features = processor(images=features["images"], text=features["prompt"], add_special_tokens=False) + + prompt_input_ids = processed_features["input_ids"][0] + pixel_values = processed_features["pixel_values"][0] + chosen_input_ids = tokenizer(features["chosen"], add_special_tokens=False)["input_ids"] + rejected_input_ids = tokenizer(features["rejected"], add_special_tokens=False)["input_ids"] + + # Add special tokens (typically for encoder-decoder models) + if add_special_tokens: + if tokenizer.bos_token_id is not None: + prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids + if tokenizer.eos_token_id is not None: + prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id] + chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id] + rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id] + + # Truncate prompt and completion sequences + if max_prompt_length is not None: + prompt_input_ids = prompt_input_ids[-max_prompt_length:] + if max_completion_length is not None: + chosen_input_ids = chosen_input_ids[:max_completion_length] + rejected_input_ids = rejected_input_ids[:max_completion_length] + + output = { + "prompt_input_ids": prompt_input_ids, + "pixel_values": pixel_values, + "chosen_input_ids": chosen_input_ids, + "rejected_input_ids": rejected_input_ids, + } + + if "pixel_attention_mask" in processed_features: + output["pixel_attention_mask"] = processed_features["pixel_attention_mask"][0] + if "image_sizes" in processed_features: + output["image_sizes"] = processed_features["image_sizes"][0] + + return output + + def _set_signature_columns_if_needed(self): + # If `self.args.remove_unused_columns` is True, non-signature columns are removed. + # By default, this method sets `self._signature_columns` to the model's expected inputs. + # In DPOTrainer, we preprocess data, so using the model's signature columns doesn't work. + # Instead, we set them to the columns expected by `DataCollatorForPreference`, hence the override. + if self._signature_columns is None: + self._signature_columns = [ + "prompt_input_ids", + "chosen_input_ids", + "rejected_input_ids", + "image_sizes", + "ref_chosen_logps", + "ref_rejected_logps", + ] + + def get_train_dataloader(self) -> DataLoader: + """ + Returns the training [`~torch.utils.data.DataLoader`]. + + Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`. + """ + + if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs: + batch_size = self.args.precompute_ref_batch_size or self.args.per_device_train_batch_size + dataloader_params = { + "batch_size": batch_size, + "collate_fn": self.data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "shuffle": False, + } + + # prepare dataloader + data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params)) + + ref_chosen_logps = [] + ref_rejected_logps = [] + for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"): + ref_chosen_logp, ref_rejected_logp = self.compute_ref_log_probs(padded_batch) + ref_chosen_logp, ref_rejected_logp = self.accelerator.gather_for_metrics( + (ref_chosen_logp, ref_rejected_logp) + ) + ref_chosen_logps.append(ref_chosen_logp.cpu()) + ref_rejected_logps.append(ref_rejected_logp.cpu()) + + # Unnecessary cache clearing to avoid OOM + empty_cache() + self.accelerator.free_memory() + + all_ref_chosen_logps = torch.cat(ref_chosen_logps).float().numpy() + all_ref_rejected_logps = torch.cat(ref_rejected_logps).float().numpy() + + self.train_dataset = self.train_dataset.add_column(name="ref_chosen_logps", column=all_ref_chosen_logps) + self.train_dataset = self.train_dataset.add_column( + name="ref_rejected_logps", column=all_ref_rejected_logps + ) + + self._precomputed_train_ref_log_probs = True + + return super().get_train_dataloader() + + def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: + """ + Returns the evaluation [`~torch.utils.data.DataLoader`]. + + Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`. + + Args: + eval_dataset (`torch.utils.data.Dataset`, *optional*): + If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted + by the `model.forward()` method are automatically removed. It must implement `__len__`. + """ + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + + if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs: + batch_size = self.args.precompute_ref_batch_size or self.args.per_device_eval_batch_size + dataloader_params = { + "batch_size": batch_size, + "collate_fn": self.data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "shuffle": False, + } + + # prepare dataloader + data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params)) + + ref_chosen_logps = [] + ref_rejected_logps = [] + for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"): + ref_chosen_logp, ref_rejected_logp = self.compute_ref_log_probs(padded_batch) + ref_chosen_logp, ref_rejected_logp = self.accelerator.gather_for_metrics( + (ref_chosen_logp, ref_rejected_logp) + ) + ref_chosen_logps.append(ref_chosen_logp.cpu()) + ref_rejected_logps.append(ref_rejected_logp.cpu()) + + all_ref_chosen_logps = torch.cat(ref_chosen_logps).float().numpy() + all_ref_rejected_logps = torch.cat(ref_rejected_logps).float().numpy() + + eval_dataset = eval_dataset.add_column(name="ref_chosen_logps", column=all_ref_chosen_logps) + eval_dataset = eval_dataset.add_column(name="ref_rejected_logps", column=all_ref_rejected_logps) + + # Save calculated ref_chosen_logps and ref_rejected_logps to the eval_dataset for subsequent runs + if self.eval_dataset is not None: + self.eval_dataset = eval_dataset + self._precomputed_eval_ref_log_probs = True + + return super().get_eval_dataloader(eval_dataset=eval_dataset) + + @contextmanager + def null_ref_context(self): + """Context manager for handling null reference model (that is, peft adapter manipulation).""" + with ( + self.accelerator.unwrap_model(self.model).disable_adapter() + if self.is_peft_model and not self.ref_adapter_name + else nullcontext() + ): + if self.ref_adapter_name: + self.model.set_adapter(self.ref_adapter_name) + yield + if self.ref_adapter_name: + self.model.set_adapter(self.model_adapter_name or "default") + + def compute_ref_log_probs(self, batch: dict[str, torch.LongTensor]) -> dict: + """Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset.""" + compte_ref_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + with torch.no_grad(), compte_ref_context_manager: + if self.ref_model is None: + with self.null_ref_context(): + ref_model_output = self.concatenated_forward(self.model, batch, is_ref_model=True) + else: + ref_model_output = self.concatenated_forward(self.ref_model, batch, is_ref_model=True) + return ref_model_output["chosen_logps"], ref_model_output["rejected_logps"] + + @staticmethod + def concatenated_inputs( + batch: dict[str, Union[list, torch.LongTensor]], padding_value: int + ) -> dict[str, torch.LongTensor]: + """ + Concatenate the `chosen` and `rejected` inputs from the batch into a single tensor for both the prompt + and completion sequences. + + Args: + batch (`dict[str, Union[list, torch.LongTensor]]`): + A batch of input data. The batch must contain the following keys: + + - `"prompt_input_ids"`: Tensor of shape `(batch_size, prompt_length)` representing the prompt input IDs. + - `"chosen_input_ids"`: Tensor of shape `(batch_size, chosen_length)` representing the chosen completion input IDs. + - `"rejected_input_ids"`: Tensor of shape `(batch_size, rejected_length)` representing the rejected completion input IDs. + - `"prompt_pixel_values"` (optional): Tensor for pixel values, if available. + - `"prompt_pixel_attention_mask"` (optional): Tensor for pixel attention masks, if available. + + padding_value (`int`): + The padding value to use for the concatenated completion sequences (`chosen_input_ids` and + `rejected_input_ids`). + + Returns: + `dict[str, torch.LongTensor]`: A dictionary containing: + + - `"prompt_input_ids"`: Concatenated prompt input IDs of shape `(2 * batch_size, prompt_length)`. + - `"completion_input_ids"`: Concatenated chosen and rejected completion input IDs of shape `(2 * batch_size, max_completion_length)`. + - `"prompt_attention_mask"`: Concatenated prompt attention masks of shape `(2 * batch_size, prompt_length)`. + - `"completion_attention_mask"`: Concatenated chosen and rejected attention masks of shape `(2 * batch_size, max_completion_length)`. + - `"pixel_values"` (optional): Concatenated pixel values if `"prompt_pixel_values"` are present. + - `"pixel_attention_mask"` (optional): Concatenated pixel attention masks if `"prompt_pixel_attention_mask"` are present. + + Notes: + The completion input IDs and attention masks are padded to the maximum completion length of the chosen + or rejected sequences. + """ + output = {} + + # For the prompt, the input_ids are the same for both the chosen and rejected responses + output["prompt_input_ids"] = torch.cat([batch["prompt_input_ids"], batch["prompt_input_ids"]], dim=0) + output["prompt_attention_mask"] = torch.cat( + [batch["prompt_attention_mask"], batch["prompt_attention_mask"]], dim=0 + ) + if "pixel_values" in batch: + output["pixel_values"] = torch.cat([batch["pixel_values"], batch["pixel_values"]], dim=0) + + if "pixel_attention_mask" in batch: + output["pixel_attention_mask"] = torch.cat( + [batch["pixel_attention_mask"], batch["pixel_attention_mask"]], dim=0 + ) + if "image_sizes" in batch: + output["image_sizes"] = torch.cat([batch["image_sizes"], batch["image_sizes"]], dim=0) + + # Concatenate the chosen and rejected completions + max_completion_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) + output["completion_input_ids"] = torch.cat( + ( + pad_to_length(batch["chosen_input_ids"], max_completion_length, pad_value=padding_value), + pad_to_length(batch["rejected_input_ids"], max_completion_length, pad_value=padding_value), + ), + ) + output["completion_attention_mask"] = torch.cat( + ( + pad_to_length(batch["chosen_attention_mask"], max_completion_length, pad_value=0), + pad_to_length(batch["rejected_attention_mask"], max_completion_length, pad_value=0), + ), + ) + + return output + + def dpo_loss( + self, + chosen_logps: torch.FloatTensor, + rejected_logps: torch.FloatTensor, + ref_chosen_logps: torch.FloatTensor, + ref_rejected_logps: torch.FloatTensor, + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """ + Compute the DPO loss for a batch of policy and reference model log probabilities. + + Args: + chosen_logps (`torch.FloatTensor`): + Log probabilities of the model for the chosen responses. Shape: `(batch_size,)`. + rejected_logps (`torch.FloatTensor`): + Log probabilities of the model for the rejected responses. Shape: `(batch_size,)`. + ref_chosen_logps (`torch.FloatTensor`): + Log probabilities of the reference model for the chosen responses. Shape: `(batch_size,)`. + ref_rejected_logps (`torch.FloatTensor`): + Log probabilities of the reference model for the rejected responses. Shape: `(batch_size,)`. + + Returns: + A tuple of three tensors: `(losses, chosen_rewards, rejected_rewards)`. + The losses tensor contains the DPO loss for each example in the batch. + The `chosen_rewards` and `rejected_rewards` tensors contain the rewards for the chosen and rejected + responses, respectively. + """ + device = self.accelerator.device + + # Get the log ratios for the chosen and rejected responses + chosen_logratios = chosen_logps.to(device) - (not self.reference_free) * ref_chosen_logps.to(device) + rejected_logratios = rejected_logps.to(device) - (not self.reference_free) * ref_rejected_logps.to(device) + + if self.f_divergence_type == FDivergenceType.ALPHA_DIVERGENCE.value: + # The alpha-divergence formula: (1 - u^-alpha) / alpha + # The divergence difference between the chosen and rejected sample is: + # (1 - u[w]^-alpha) / alpha - (1 - u[l]^-alpha) / alpha + # = (u[l]^-alpha - u[w]^-alpha) / alpha + # where u[w] and u[l] are the policy/reference probability ratios + # for the chosen and rejected samples, respectively. + alpha_coef = FDivergenceConstants.ALPHA_DIVERGENCE_COEF_DEFAULT + if self.f_divergence_params and FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY in self.f_divergence_params: + alpha_coef = float(self.f_divergence_params[FDivergenceConstants.ALPHA_DIVERGENCE_COEF_KEY]) + logits = (cap_exp(rejected_logratios * -alpha_coef) - cap_exp(chosen_logratios * -alpha_coef)) / alpha_coef + else: + logratios = chosen_logps - rejected_logps + if self.reference_free: + ref_logratios = torch.tensor([0], dtype=logratios.dtype, device=logratios.device) + else: + ref_logratios = ref_chosen_logps - ref_rejected_logps + + logratios = logratios.to(self.accelerator.device) + ref_logratios = ref_logratios.to(self.accelerator.device) + logits = logratios - ref_logratios + + if self.f_divergence_type == FDivergenceType.JS_DIVERGENCE.value: + # The js-divergence formula: log(2 * u / (1 + u)) + # The divergence difference between the chosen and rejected sample is: + # log(2 * u[w] / (1 + u[w])) - log(2 * u[l] / (1 + u[l])) + # = log(u[w]) - log(u[l]) - (log(1 + u[w]) - log(1 + u[l])) + # where u[w] and u[l] are the policy/reference probability ratios + # for the chosen and rejected samples, respectively. + logits -= F.softplus(chosen_logratios) - F.softplus(rejected_logratios) + + # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. + # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the + # labels and calculates a conservative DPO loss. + if self.loss_type == "sigmoid": + losses = ( + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) + + elif self.loss_type == "robust": + losses = ( + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + + F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) / (1 - 2 * self.label_smoothing) + + elif self.loss_type == "exo_pair": + # eqn (16) of the EXO paper: https://huggingface.co/papers/2402.00856 + import math + + if self.label_smoothing == 0: + self.label_smoothing = 1e-3 + losses = (self.beta * logits).sigmoid() * ( + F.logsigmoid(self.beta * logits) - math.log(1 - self.label_smoothing) + ) + (-self.beta * logits).sigmoid() * (F.logsigmoid(-self.beta * logits) - math.log(self.label_smoothing)) + + elif self.loss_type == "hinge": + losses = torch.relu(1 - self.beta * logits) + + elif self.loss_type == "ipo": + # eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper. + losses = (logits - 1 / (2 * self.beta)) ** 2 + + elif self.loss_type == "bco_pair": + chosen_logratios = chosen_logps - ref_chosen_logps + rejected_logratios = rejected_logps - ref_rejected_logps + chosen_rewards = self.beta * chosen_logratios + rejected_rewards = self.beta * rejected_logratios + rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach() + self.running.update(rewards) + delta = self.running.mean + losses = -F.logsigmoid((self.beta * chosen_logratios) - delta) - F.logsigmoid( + -(self.beta * rejected_logratios - delta) + ) + + elif self.loss_type == "sppo_hard": + # In the paper (https://huggingface.co/papers/2405.00675), SPPO employs a soft probability approach, + # estimated using the PairRM score. The probability calculation is conducted outside of the trainer class. + # The version described here is the hard probability version, where P in Equation (4.7) of Algorithm 1 is + # set to 1 for the winner and 0 for the loser. + a = chosen_logps - ref_chosen_logps + b = rejected_logps - ref_rejected_logps + losses = (a - 0.5 / self.beta) ** 2 + (b + 0.5 / self.beta) ** 2 + + elif self.loss_type == "nca_pair": + chosen_rewards = (chosen_logps - ref_chosen_logps) * self.beta + rejected_rewards = (rejected_logps - ref_rejected_logps) * self.beta + losses = ( + -F.logsigmoid(chosen_rewards) + - 0.5 * F.logsigmoid(-chosen_rewards) + - 0.5 * F.logsigmoid(-rejected_rewards) + ) + + elif self.loss_type == "aot_pair": + chosen_logratios = chosen_logps - ref_chosen_logps + rejected_logratios = rejected_logps - ref_rejected_logps + chosen_logratios_sorted, _ = torch.sort(chosen_logratios, dim=0) + rejected_logratios_sorted, _ = torch.sort(rejected_logratios, dim=0) + delta = chosen_logratios_sorted - rejected_logratios_sorted + losses = ( + -F.logsigmoid(self.beta * delta) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * delta) * self.label_smoothing + ) + + elif self.loss_type == "aot": + logratios = chosen_logps - rejected_logps + ref_logratios = ref_chosen_logps - ref_rejected_logps + logratios_sorted, _ = torch.sort(logratios, dim=0) + ref_logratios_sorted, _ = torch.sort(ref_logratios, dim=0) + delta = logratios_sorted - ref_logratios_sorted + losses = ( + -F.logsigmoid(self.beta * delta) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * delta) * self.label_smoothing + ) + + elif self.loss_type == "apo_zero": + # Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266) + # Use this loss when you believe the chosen outputs are better than your model's default output + losses_chosen = 1 - F.sigmoid(self.beta * chosen_logratios) # Increase chosen likelihood + losses_rejected = F.sigmoid(self.beta * rejected_logratios) # Decrease rejected likelihood + losses = losses_chosen + losses_rejected + + elif self.loss_type == "apo_down": + # Eqn (8) of the APO paper (https://huggingface.co/papers/2408.06266) + # Use this loss when you believe the chosen outputs are worse than your model's default output. + # Decrease chosen likelihood and decrease rejected likelihood more + losses_chosen = F.sigmoid(self.beta * chosen_logratios) + losses_rejected = 1 - F.sigmoid(self.beta * (chosen_logratios - rejected_logratios)) + losses = losses_chosen + losses_rejected + + elif self.loss_type == "discopop": + # Eqn (5) of the DiscoPOP paper (https://huggingface.co/papers/2406.08414) + # This loss was discovered with LLM discovery + logratios = chosen_logps - rejected_logps + ref_logratios = ref_chosen_logps - ref_rejected_logps + logits = logratios - ref_logratios + logits = logits * self.beta + # Modulate the mixing coefficient based on the log ratio magnitudes + log_ratio_modulation = torch.sigmoid(logits / self.args.discopop_tau) + logistic_component = -F.logsigmoid(logits) + exp_component = torch.exp(-logits) + # Blend between logistic and exponential component based on log ratio modulation + losses = logistic_component * (1 - log_ratio_modulation) + exp_component * log_ratio_modulation + + else: + raise ValueError( + f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'exo_pair', " + "'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'discopop', 'apo_zero', 'apo_down']" + ) + + chosen_rewards = self.beta * (chosen_logps.to(device) - ref_chosen_logps.to(device)).detach() + rejected_rewards = self.beta * (rejected_logps.to(device) - ref_rejected_logps.to(device)).detach() + + return losses, chosen_rewards, rejected_rewards + + def _compute_loss_liger(self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]): + unwrapped_model = self.accelerator.unwrap_model(model) + concatenated_batch = self.concatenated_inputs(batch, padding_value=self.padding_value) + + model_kwargs = {} + if self.aux_loss_enabled: + model_kwargs["output_router_logits"] = True + + # Add the pixel values and attention masks for vision models + if "pixel_values" in concatenated_batch: + model_kwargs["pixel_values"] = concatenated_batch["pixel_values"] + if "pixel_attention_mask" in concatenated_batch: + model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"] + if "image_sizes" in concatenated_batch: + model_kwargs["image_sizes"] = concatenated_batch["image_sizes"] + + prompt_attention_mask = concatenated_batch["prompt_attention_mask"] + completion_attention_mask = concatenated_batch["completion_attention_mask"] + + if self.is_encoder_decoder: + # 1. Get encoder outputs + encoder_outputs = unwrapped_model.get_encoder()( + concatenated_batch["prompt_input_ids"], + attention_mask=concatenated_batch["prompt_attention_mask"], + return_dict=True, + ) + # 2. Prepare decoder inputs + decoder_input_ids = shift_tokens_right( + concatenated_batch["completion_input_ids"], + unwrapped_model.config.decoder_start_token_id, + ) + # 3. Get decoder outputs + decoder_outputs = unwrapped_model.get_decoder()( + input_ids=decoder_input_ids, + attention_mask=concatenated_batch["completion_attention_mask"], + encoder_hidden_states=encoder_outputs.last_hidden_state, + encoder_attention_mask=concatenated_batch["prompt_attention_mask"], + use_cache=False, + ) + hidden_states = decoder_outputs.last_hidden_state + + ref_hidden_states = None + if not self.reference_free and self.ref_model is not None: + unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model) + ref_encoder_outputs = unwrapped_ref_model.get_encoder()( + concatenated_batch["prompt_input_ids"], + attention_mask=concatenated_batch["prompt_attention_mask"], + return_dict=True, + ) + ref_decoder_outputs = unwrapped_ref_model.get_decoder()( + input_ids=decoder_input_ids, + attention_mask=concatenated_batch["completion_attention_mask"], + encoder_hidden_states=ref_encoder_outputs.last_hidden_state, + encoder_attention_mask=concatenated_batch["prompt_attention_mask"], + use_cache=False, + ) + ref_hidden_states = ref_decoder_outputs.last_hidden_state + elif not self.reference_free: + with self.null_ref_context(): + ref_encoder_outputs = unwrapped_model.get_encoder()( + concatenated_batch["prompt_input_ids"], + attention_mask=concatenated_batch["prompt_attention_mask"], + return_dict=True, + ) + ref_decoder_outputs = unwrapped_model.get_decoder()( + input_ids=decoder_input_ids, + attention_mask=concatenated_batch["completion_attention_mask"], + encoder_hidden_states=ref_encoder_outputs.last_hidden_state, + encoder_attention_mask=concatenated_batch["prompt_attention_mask"], + use_cache=False, + ) + ref_hidden_states = ref_decoder_outputs.last_hidden_state + + labels = concatenated_batch["completion_input_ids"] + loss_mask = completion_attention_mask.bool() + else: + # For decoder-only models + input_ids = torch.cat( + (concatenated_batch["prompt_input_ids"], concatenated_batch["completion_input_ids"]), dim=1 + ) + attention_mask = torch.cat( + (concatenated_batch["prompt_attention_mask"], concatenated_batch["completion_attention_mask"]), + dim=1, + ) + # Mask the prompt but not the completion for the loss + loss_mask = torch.cat( + (torch.zeros_like(prompt_attention_mask), completion_attention_mask), + dim=1, + ) + + # Flush and truncate + if self.max_length is not None and self.max_length < attention_mask.size(1): + if self.truncation_mode == "keep_start": + # Flush left to reduce the memory usage + # [[0, 0, x, x, x, x], -> [[x, x, x, x], + # [0, x, x, x, 0, 0]] [x, x, x, 0]] + attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) + attention_mask = attention_mask[:, : self.max_length] + input_ids = input_ids[:, : self.max_length] + loss_mask = loss_mask[:, : self.max_length] + elif self.truncation_mode == "keep_end": + # Flush right before truncating left, then flush left + # [[0, 0, x, x, x, x], -> [[0, 0, x, x], + # [0, x, x, x, 0, 0]] [0, x, x, x]] + attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask) + input_ids = input_ids[:, -self.max_length :] + attention_mask = attention_mask[:, -self.max_length :] + loss_mask = loss_mask[:, -self.max_length :] + attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) + else: + raise ValueError( + f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', " + "'keep_start']." + ) + else: + # Flush left to reduce the memory usage + # [[0, 0, x, x, x, x], -> [[x, x, x, x], + # [0, x, x, x, 0, 0]] [x, x, x, 0]] + attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) + + # Add logits_to_keep optimization + if self.use_logits_to_keep: + first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min() + logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1 + model_kwargs["logits_to_keep"] = logits_to_keep + + model_kwargs["output_hidden_states"] = True + + # Add padding-free training support + if self.padding_free: + input_ids = input_ids[attention_mask.bool()].unsqueeze(0) + loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0) + position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1 + model_kwargs["position_ids"] = position_ids + else: + model_kwargs["attention_mask"] = attention_mask + + # Get the base model outputs (before LM head) + if hasattr(unwrapped_model, "get_decoder"): + base_model = unwrapped_model.get_decoder() + else: + base_model = getattr(unwrapped_model, self.args.base_model_attribute_name, unwrapped_model) + + outputs = base_model( + input_ids, + use_cache=False, + **model_kwargs, + ) + hidden_states = outputs.last_hidden_state[:, :-1] + + # Get reference hidden states if needed + ref_hidden_states = None + if not self.reference_free and self.ref_model is not None: + unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model) + if hasattr(unwrapped_ref_model, "get_decoder"): + ref_base_model = unwrapped_ref_model.get_decoder() + else: + ref_base_model = getattr( + unwrapped_ref_model, self.args.base_model_attribute_name, unwrapped_ref_model + ) + + ref_outputs = ref_base_model( + input_ids, + use_cache=False, + **model_kwargs, + ) + ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] + elif not self.reference_free: + if hasattr(unwrapped_model, "get_decoder"): + ref_base_model = unwrapped_model.get_decoder() + else: + ref_base_model = getattr(unwrapped_model, self.args.base_model_attribute_name, unwrapped_model) + with self.null_ref_context(): + ref_outputs = ref_base_model( + input_ids, + attention_mask=attention_mask, + use_cache=False, + **model_kwargs, + ) + ref_hidden_states = ref_outputs.last_hidden_state[:, :-1] + + masked_input_ids = torch.where(loss_mask != 0, input_ids, self.label_pad_token_id) + labels = masked_input_ids[:, 1:] # Shift right for casual LM + + # Get the LM head + lm_head = unwrapped_model.get_output_embeddings() + + # Get reference model weights if needed + ref_weight = None + ref_bias = None + if not self.reference_free: + if self.ref_model is not None: + unwrapped_ref_model = self.accelerator.unwrap_model(self.ref_model) + ref_lm_head = unwrapped_ref_model.get_output_embeddings() + else: + with self.null_ref_context(): + ref_lm_head = unwrapped_model.get_output_embeddings() + ref_weight = ref_lm_head.weight + ref_bias = ref_lm_head.bias if hasattr(ref_lm_head, "bias") else None + + # Compute loss using Liger kernel + loss_output = self.dpo_loss_fn( + lm_head.weight, + hidden_states, + labels, + bias=lm_head.bias if hasattr(lm_head, "bias") else None, + ref_input=ref_hidden_states if not self.reference_free else None, + ref_weight=ref_weight if not self.reference_free else None, + ref_bias=ref_bias if not self.reference_free else None, + ) + ( + loss, + (chosen_logps, rejected_logps, chosen_logits_mean, rejected_logits_mean, nll_loss, *aux_outputs), + ) = loss_output + + output = { + "loss": loss, + "chosen_logps": chosen_logps, + "rejected_logps": rejected_logps, + "mean_chosen_logits": chosen_logits_mean, + "mean_rejected_logits": rejected_logits_mean, + "nll_loss": nll_loss, + "chosen_rewards": aux_outputs[0], + "rejected_rewards": aux_outputs[1], + } + if self.aux_loss_enabled: + output["aux_loss"] = outputs.aux_loss + + return output + + def concatenated_forward( + self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]], is_ref_model: bool = False + ): + """ + Runs the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + + We do this to avoid doing two forward passes, because it's faster for FSDP. + + Args: + model: + Model to run the forward pass on. + batch: + Batch of input data. + is_ref_model: + Whether this method is being called for the reference model. If `True`, length desensitization is not + applied. + """ + num_examples = batch["prompt_input_ids"].shape[0] + + concatenated_batch = self.concatenated_inputs(batch, padding_value=self.padding_value) + + model_kwargs = {"use_cache": False} + if self.aux_loss_enabled: + model_kwargs["output_router_logits"] = True + + # Add the pixel values and attention masks for vision models + if "pixel_values" in concatenated_batch: + model_kwargs["pixel_values"] = concatenated_batch["pixel_values"] + if "pixel_attention_mask" in concatenated_batch: + model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"] + if "image_sizes" in concatenated_batch: + model_kwargs["image_sizes"] = concatenated_batch["image_sizes"] + + prompt_input_ids = concatenated_batch["prompt_input_ids"] + prompt_attention_mask = concatenated_batch["prompt_attention_mask"] + completion_input_ids = concatenated_batch["completion_input_ids"] + completion_attention_mask = concatenated_batch["completion_attention_mask"] + if self.is_encoder_decoder: + labels = completion_input_ids + labels[completion_attention_mask == 0] = self.label_pad_token_id + outputs = model( + input_ids=prompt_input_ids, + attention_mask=prompt_attention_mask, + labels=labels, # we need the labels for the logits to be returned + **model_kwargs, + ) + logits = outputs.logits + loss_mask = completion_attention_mask.bool() + else: + # Concatenate the prompt and completion inputs + input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1) + attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1) + # Mask the prompt but not the completion for the loss + loss_mask = torch.cat( + (torch.zeros_like(prompt_attention_mask), completion_attention_mask), + dim=1, + ) + + # Flush and truncate + if self.max_length is not None and self.max_length < attention_mask.size(1): + if self.truncation_mode == "keep_start": + # Flush left to reduce the memory usage + # [[0, 0, x, x, x, x], -> [[x, x, x, x], + # [0, x, x, x, 0, 0]] [x, x, x, 0]] + attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) + attention_mask = attention_mask[:, : self.max_length] + input_ids = input_ids[:, : self.max_length] + loss_mask = loss_mask[:, : self.max_length] + elif self.truncation_mode == "keep_end": + # Flush right before truncating left, then flush left + # [[0, 0, x, x, x, x], -> [[0, 0, x, x], + # [0, x, x, x, 0, 0]] [0, x, x, x]] + attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask) + input_ids = input_ids[:, -self.max_length :] + attention_mask = attention_mask[:, -self.max_length :] + loss_mask = loss_mask[:, -self.max_length :] + attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) + else: + raise ValueError( + f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', " + "'keep_start']." + ) + else: + # Flush left to reduce the memory usage + # [[0, 0, x, x, x, x], -> [[x, x, x, x], + # [0, x, x, x, 0, 0]] [x, x, x, 0]] + attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask) + + if self.use_logits_to_keep: + # Compute logits_to_keep based on loss_mask pattern: + # [[0, 0, 0, x, x, x, x], + # [0, 0, 0, x, x, x, 0]] + # ^ start computing logits from here ([:, -(7-3+1):]) + first_compute_index = loss_mask.nonzero(as_tuple=True)[1].min() + logits_to_keep = (loss_mask.shape[1] - first_compute_index).item() + 1 # +1 for the first label + model_kwargs["logits_to_keep"] = logits_to_keep + + model_kwargs["output_hidden_states"] = True + + if self.padding_free: + # Flatten the input_ids, position_ids, and loss_mask + # input_ids = [[a, b, c, 0], -> input_ids = [[a, b, c, d, e, f, g]] + # [d, e, f, g]] position_ids = [[0, 1, 2, 0, 1, 2, 3]] + input_ids = input_ids[attention_mask.bool()].unsqueeze(0) + loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0) + position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1 + model_kwargs["position_ids"] = position_ids + else: + model_kwargs["attention_mask"] = attention_mask + + outputs = model(input_ids, **model_kwargs) + logits = outputs.logits + + # Offset the logits by one to align with the labels + labels = torch.roll(input_ids, shifts=-1, dims=1) + loss_mask = torch.roll(loss_mask, shifts=-1, dims=1).bool() + + if self.use_logits_to_keep: + # Align labels with logits + # logits: -, -, [x2, x3, x4, x5, x6] + # ^ --------- ^ after logits[:, :-1, :] + # labels: [y0, y1, y2, y3, y4, y5, y6] + # ^ --------- ^ with logits_to_keep=4, [:, -4:] + # loss_mask: [0, 0, 0, 1, 1, 1, 1] + labels = labels[:, -logits_to_keep:] + loss_mask = loss_mask[:, -logits_to_keep:] + + if logits.shape[:2] != labels.shape[:2]: + # for llava, the returned logits include the image tokens (placed before the text tokens) + seq_len = labels.shape[1] + logits = logits[:, -seq_len:] + + # Compute the log probabilities of the labels + labels[~loss_mask] = 0 # dummy token; we'll ignore the losses on these tokens later + per_token_logps = selective_log_softmax(logits, labels) + per_token_logps[~loss_mask] = 0 + per_token_logps = torch.roll(per_token_logps, shifts=1, dims=1) + + if self.padding_free: + # Unflatten the per_token_logps (shape: [1, sum_seq_len] -> [batch_size, seq_len]) + batch_size, seq_len = attention_mask.shape + per_token_logps_ = torch.zeros( + batch_size, seq_len, device=outputs.logits.device, dtype=outputs.logits.dtype + ) + per_token_logps_[attention_mask.bool()] = per_token_logps + per_token_logps = per_token_logps_ + + all_logps = per_token_logps[:, 1:].sum(-1) + + output = {} + + if self.use_weighting: + with torch.no_grad(): + # Eq (2) of the WPO paper: https://huggingface.co/papers/2406.11827 + logprobs = F.log_softmax(logits, dim=-1) + weights_adjustment_factor = torch.logsumexp(2 * logprobs, dim=-1) # same as sum(probs**2) in log space + per_token_logps_adjusted = per_token_logps - weights_adjustment_factor + all_weights = (per_token_logps_adjusted * loss_mask).sum(-1) / loss_mask.sum(-1) + chosen_weights = all_weights[:num_examples] + rejected_weights = all_weights[num_examples:] + output["policy_weights"] = torch.clamp(torch.exp(chosen_weights + rejected_weights), max=1) + + if self.args.rpo_alpha is not None: + # Only use the chosen logits for the RPO loss + chosen_logits = logits[:num_examples, :-1] if not self.is_encoder_decoder else logits[:num_examples] + chosen_labels = labels[:num_examples, :-1] if not self.is_encoder_decoder else labels[:num_examples] + + # Compute the log probabilities of the labels + output["nll_loss"] = F.cross_entropy( + torch.flatten(chosen_logits, end_dim=1), torch.flatten(chosen_labels, end_dim=1), ignore_index=0 + ) + + if self.loss_type == "ipo": + all_logps = all_logps / loss_mask.sum(-1) + + if self.args.ld_alpha is not None and not is_ref_model: + # Compute response lengths based on loss_mask + completion_lengths = loss_mask.sum(dim=1) + + chosen_lengths = completion_lengths[:num_examples] + rejected_lengths = completion_lengths[num_examples:] + public_lengths = torch.min(chosen_lengths, rejected_lengths) # l_p in the paper + public_lengths = torch.cat([public_lengths, public_lengths], dim=0) + + seq_len = per_token_logps.size(1) + position_ids = torch.arange(seq_len, device=per_token_logps.device).expand_as(per_token_logps) + + ld_mask = position_ids < public_lengths.unsqueeze(1) + mask = position_ids < completion_lengths.unsqueeze(1) + + front_mask = (ld_mask & mask).float() + rear_mask = (~ld_mask & mask).float() + front_logps = (per_token_logps * front_mask).sum(dim=1) + rear_logps = (per_token_logps * rear_mask).sum(dim=1) + + all_logps = front_logps + self.args.ld_alpha * rear_logps + + output["chosen_logps"] = all_logps[:num_examples] + output["rejected_logps"] = all_logps[num_examples:] + + # Compute the mean logits + if self.padding_free: + # position_ids contains a sequence of range identifiers (e.g., [[0, 1, 2, 0, 1, 2, 3, ...]]). + # There are 2*num_examples ranges in total: the first half corresponds to the chosen tokens, + # and the second half to the rejected tokens. + # To find the start of the rejected tokens, we look for the num_examples+1-th zero in pos_id. + split_idx = (position_ids == 0).nonzero(as_tuple=True)[1][num_examples] + mean_chosen_logits = logits[0, :split_idx][loss_mask[0, :split_idx]].mean() + mean_rejected_logits = logits[0, split_idx:][loss_mask[0, split_idx:]].mean() + else: + mean_chosen_logits = logits[:num_examples][loss_mask[:num_examples]].mean() + mean_rejected_logits = logits[num_examples:][loss_mask[num_examples:]].mean() + + output["mean_chosen_logits"] = mean_chosen_logits + output["mean_rejected_logits"] = mean_rejected_logits + + if self.aux_loss_enabled: + output["aux_loss"] = outputs.aux_loss + + return output + + def get_batch_loss_metrics( + self, + model, + batch: dict[str, Union[list, torch.LongTensor]], + train_eval: Literal["train", "eval"] = "train", + ): + """Compute the DPO loss and other metrics for the given batch of inputs for train or test.""" + metrics = {} + + if self.args.use_liger_loss: + model_output = self._compute_loss_liger(model, batch) + losses = model_output["loss"] + chosen_rewards = model_output["chosen_rewards"] + rejected_rewards = model_output["rejected_rewards"] + else: + model_output = self.concatenated_forward(model, batch) + + # if ref_chosen_logps and ref_rejected_logps in batch use them, otherwise use the reference model + if "ref_chosen_logps" in batch and "ref_rejected_logps" in batch: + ref_chosen_logps = batch["ref_chosen_logps"] + ref_rejected_logps = batch["ref_rejected_logps"] + else: + ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch) + + losses, chosen_rewards, rejected_rewards = self.dpo_loss( + model_output["chosen_logps"], model_output["rejected_logps"], ref_chosen_logps, ref_rejected_logps + ) + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + if self.args.rpo_alpha is not None: + losses = losses + self.args.rpo_alpha * model_output["nll_loss"] # RPO loss from V3 of the paper + + if self.use_weighting: + losses = losses * model_output["policy_weights"] + + if self.aux_loss_enabled: + losses = losses + self.aux_loss_coef * model_output["aux_loss"] + + prefix = "eval_" if train_eval == "eval" else "" + metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean().item() + metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean().item() + metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean().item() + metrics[f"{prefix}rewards/margins"] = ( + self.accelerator.gather_for_metrics(chosen_rewards - rejected_rewards).mean().item() + ) + metrics[f"{prefix}logps/chosen"] = ( + self.accelerator.gather_for_metrics(model_output["chosen_logps"]).detach().mean().item() + ) + metrics[f"{prefix}logps/rejected"] = ( + self.accelerator.gather_for_metrics(model_output["rejected_logps"]).detach().mean().item() + ) + metrics[f"{prefix}logits/chosen"] = ( + self.accelerator.gather_for_metrics(model_output["mean_chosen_logits"]).detach().mean().item() + ) + metrics[f"{prefix}logits/rejected"] = ( + self.accelerator.gather_for_metrics(model_output["mean_rejected_logits"]).detach().mean().item() + ) + if self.args.rpo_alpha is not None: + metrics[f"{prefix}nll_loss"] = ( + self.accelerator.gather_for_metrics(model_output["nll_loss"]).detach().mean().item() + ) + if self.aux_loss_enabled: + metrics[f"{prefix}aux_loss"] = ( + self.accelerator.gather_for_metrics(model_output["aux_loss"]).detach().mean().item() + ) + + return losses.mean(), metrics + + def compute_loss( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: dict[str, Union[torch.Tensor, Any]], + return_outputs=False, + num_items_in_batch=None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]: + compute_loss_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + with compute_loss_context_manager: + loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") + + # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class: + loss = loss.to(self.args.device) + # force log the metrics + self.store_metrics(metrics, train_eval="train") + + if return_outputs: + return loss, metrics + + return loss + + def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]: + """Generate samples from the model and reference model for the given batch of inputs.""" + + # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with + # the torch amp context manager as some hidden states are silently casted to full precision. + generate_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with generate_context_manager: + policy_output = model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.padding_value, + ) + + # if ref_output in batch use that otherwise use the reference model + if "ref_output" in batch: + ref_output = batch["ref_output"] + else: + if self.ref_model is None: + with self.null_ref_context(): + ref_output = self.model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.padding_value, + ) + else: + ref_output = self.ref_model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.padding_value, + ) + + policy_output = pad_to_length(policy_output, self.max_length, self.padding_value) + policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) + + ref_output = pad_to_length(ref_output, self.max_length, self.padding_value) + ref_output_decoded = self.processing_class.batch_decode(ref_output, skip_special_tokens=True) + + return policy_output_decoded, ref_output_decoded + + def prediction_step( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[list[str]] = None, + ): + if ignore_keys is None: + if hasattr(model, "config"): + ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + prediction_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with torch.no_grad(), prediction_context_manager: + loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval") + + # force log the metrics + self.store_metrics(metrics, train_eval="eval") + + if prediction_loss_only: + return loss.detach(), None, None + + # logits for the chosen and rejected samples from model + logits_dict = { + "eval_logits/chosen": metrics["eval_logits/chosen"], + "eval_logits/rejected": metrics["eval_logits/rejected"], + } + logits = [v for k, v in logits_dict.items() if k not in ignore_keys] + logits = torch.tensor(logits, device=self.accelerator.device) + labels = torch.zeros(logits.shape[0], device=self.accelerator.device) + + return (loss.detach(), logits, labels) + + def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: + for key, value in metrics.items(): + self._stored_metrics[train_eval][key].append(value) + + def evaluation_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: Optional[bool] = None, + ignore_keys: Optional[list[str]] = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Overriding built-in evaluation loop to store metrics for each batch. + Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. + """ + + # Sample and save to game log if requested (for one batch to save time) + if self.generate_during_eval: + # Generate random indices within the range of the total number of samples + num_samples = len(dataloader.dataset) + random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) + + # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader + random_batch_dataset = dataloader.dataset.select(random_indices) + random_batch = self.data_collator(random_batch_dataset) + random_batch = self._prepare_inputs(random_batch) + + policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, random_batch) + + table = pd.DataFrame( + columns=["Prompt", "Policy", "Ref Model"], + data=[ + [prompt, pol[len(prompt) :], ref[len(prompt) :]] + for prompt, pol, ref in zip( + random_batch_dataset["prompt"], policy_output_decoded, ref_output_decoded + ) + ], + ) + if "wandb" in self.args.report_to and self.accelerator.is_main_process: + wandb.log({"game_log": wandb.Table(data=table)}) + + if "comet_ml" in self.args.report_to: + log_table_to_comet_experiment( + name="game_log.csv", + table=table, + ) + + # Base evaluation + initial_output = super().evaluation_loop( + dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix + ) + + return initial_output + + def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: + """ + Log `logs` on the various objects watching training, including stored metrics. + + Args: + logs (`dict[str, float]`): + The values to log. + start_time (`float` or `None`, *optional*, defaults to `None`): + Start time of the training. + """ + # logs either has 'loss' or 'eval_loss' + train_eval = "train" if "loss" in logs else "eval" + # Add averaged stored metrics to logs + for key, metrics in self._stored_metrics[train_eval].items(): + logs[key] = torch.tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + return super().log(logs, start_time) + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, list[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str` or `None`, *optional*, defaults to `None`): + Name of the model. + dataset_name (`str` or `None`, *optional*, defaults to `None`): + Name of the dataset used for training. + tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. + """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + tags = tags or set() + if isinstance(tags, str): + tags = {tags} + + if hasattr(self.model.config, "unsloth_version"): + tags.add("unsloth") + + tags.update(self._tag_names) + + citation = textwrap.dedent( + """\ + @inproceedings{rafailov2023direct, + title = {{Direct Preference Optimization: Your Language Model is Secretly a Reward Model}}, + author = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Christopher D. Manning and Stefano Ermon and Chelsea Finn}, + year = 2023, + booktitle = {Advances in Neural Information Processing Systems 36: Annual Conference on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans, LA, USA, December 10 - 16, 2023}, + url = {http://papers.nips.cc/paper_files/paper/2023/hash/a85b405ed65c6477a4fe8302b5e06ce7-Abstract-Conference.html}, + editor = {Alice Oh and Tristan Naumann and Amir Globerson and Kate Saenko and Moritz Hardt and Sergey Levine}, + }""" + ) + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), + trainer_name="DPO", + trainer_citation=citation, + paper_title="Direct Preference Optimization: Your Language Model is Secretly a Reward Model", + paper_id="2305.18290", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/gkd_config.py b/trl/trainer/gkd_config.py new file mode 100644 index 0000000000000000000000000000000000000000..dc4bb2d2ad3824419fbad335861e3fa0d69a7960 --- /dev/null +++ b/trl/trainer/gkd_config.py @@ -0,0 +1,112 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any, Optional + +from transformers import TrainingArguments + +from .sft_config import SFTConfig + + +@dataclass +class GKDConfig(SFTConfig): + """ + Configuration class for [`GKDTrainer`]. + + This class includes only the parameters that are specific to GKD training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] and [`SFTConfig`] documentation. + + Args: + temperature (`float`, *optional*, defaults to `0.9`): + Temperature for sampling. The higher the temperature, the more random the completions. + lmbda (`float`, *optional*, defaults to `0.5`): + Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy + student-generated outputs). + beta (`float`, *optional*, defaults to `0.5`): + Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When + beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence. + max_new_tokens (`int`, *optional*, defaults to `128`): + Maximum number of tokens to generate per completion. + teacher_model_name_or_path (`str` or `None`, *optional*, defaults to `None`): + Model name or path of the teacher model. If `None`, the teacher model will be the same as the model + being trained. + teacher_model_init_kwargs (`dict[str, Any]]` or `None`, *optional*, defaults to `None`): + Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model + from a string. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model. + seq_kd (`bool`, *optional*, defaults to `False`): + Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT + on teacher-generated output). + """ + + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["teacher_model_init_kwargs"] + + temperature: float = field( + default=0.9, + metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."}, + ) + lmbda: float = field( + default=0.5, + metadata={ + "help": "Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy " + "student-generated outputs)." + }, + ) + beta: float = field( + default=0.5, + metadata={ + "help": "Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence " + "loss. When beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL " + "Divergence." + }, + ) + max_new_tokens: int = field( + default=128, + metadata={"help": "Maximum number of tokens to generate per completion."}, + ) + teacher_model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": "Model name or path of the teacher model. If `None`, the teacher model will be the same as the " + "model being trained." + }, + ) + teacher_model_init_kwargs: Optional[dict[str, Any]] = field( + default=None, + metadata={ + "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the " + "teacher model from a string." + }, + ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropouts in `model`."}, + ) + seq_kd: bool = field( + default=False, + metadata={ + "help": "Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised " + "FT on teacher-generated output)." + }, + ) + + def __post_init__(self): + super().__post_init__() + # check lmbda and beta are in the range [0, 1] + if self.lmbda < 0.0 or self.lmbda > 1.0: + raise ValueError("lmbda must be in the range [0.0, 1.0].") + if self.beta < 0.0 or self.beta > 1.0: + raise ValueError("beta must be in the range [0.0, 1.0].") diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..9ecc40c72b115f81cee480a8bdb2797a13e929b2 --- /dev/null +++ b/trl/trainer/gkd_trainer.py @@ -0,0 +1,358 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import random +import textwrap +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from datasets import Dataset +from transformers import ( + AutoModelForCausalLM, + BaseImageProcessor, + DataCollator, + FeatureExtractionMixin, + GenerationConfig, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + is_wandb_available, +) +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalPrediction +from transformers.utils import is_peft_available + +from ..models import prepare_deepspeed +from ..models.utils import unwrap_model_for_generation +from .gkd_config import GKDConfig +from .sft_trainer import SFTTrainer +from .utils import ( + DataCollatorForChatML, + disable_dropout_in_model, + empty_cache, + generate_model_card, + get_comet_experiment_url, +) + + +if is_peft_available(): + from peft import PeftConfig + +if is_wandb_available(): + import wandb + + +class GKDTrainer(SFTTrainer): + _tag_names = ["trl", "gkd"] + + def __init__( + self, + model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + teacher_model: Union[PreTrainedModel, nn.Module, str] = None, + args: Optional[GKDConfig] = None, + data_collator: Optional[DataCollator] = None, # type: ignore + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, + processing_class: Optional[ + Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] + ] = None, + compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, + callbacks: Optional[list[TrainerCallback]] = None, + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + peft_config: Optional["PeftConfig"] = None, + formatting_func: Optional[Callable] = None, + ): + # add remove_unused_columns=False to the dataclass args + args.remove_unused_columns = False + data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length) + + super().__init__( + model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + peft_config=peft_config, + formatting_func=formatting_func, + ) + + if args.teacher_model_init_kwargs is None: + teacher_model_init_kwargs = {} + elif not isinstance(teacher_model, str): + raise ValueError( + "You passed teacher_model_init_kwargs to the GKDConfig, but your teacher_model is already instantiated." + ) + else: + teacher_model_init_kwargs = args.teacher_model_init_kwargs + teacher_model_init_kwargs["torch_dtype"] = ( + teacher_model_init_kwargs["torch_dtype"] + if teacher_model_init_kwargs["torch_dtype"] in ["auto", None] + else getattr(torch, teacher_model_init_kwargs["torch_dtype"]) + ) + + if isinstance(teacher_model, str): + teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs) + + # Disable dropout in the model + if args.disable_dropout: + disable_dropout_in_model(self.model) + + if self.is_deepspeed_enabled: + self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator) + else: + self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True) + + self.lmbda = args.lmbda + self.beta = args.beta + self.temperature = args.temperature + self.seq_kd = args.seq_kd + + self.generation_config = GenerationConfig( + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + do_sample=True, + top_k=0, + use_cache=False if args.gradient_checkpointing else True, + pad_token_id=self.processing_class.pad_token_id, + ) + # Set custom EOS tokens if they are specified by the model's generation + # config. This is important for models with the Llama 3 chat template, + # which use special tokens <|eot_id|> and <|eom_id|> to mark the end of + # turns or messages. + if ( + hasattr(self.model.generation_config, "eos_token_id") + and self.model.generation_config.eos_token_id is not None + ): + self.generation_config.eos_token_id = self.model.generation_config.eos_token_id + + @staticmethod + def generalized_jsd_loss( + student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean" + ): + """ + Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1) + of https://huggingface.co/papers/2306.13649 for the definition. + + Args: + student_logits: Tensor of shape (batch_size, sequence_length, vocab_size) + teacher_logits: Tensor of shape (batch_size, sequence_length, vocab_size) + labels: Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing loss + beta: Interpolation coefficient between 0 and 1 (default: 0.5) + temperature: Softmax temperature (default: 1.0) + reduction: Specifies the reduction to apply to the output (default: 'batchmean') + + Returns: + loss: Scalar tensor with the generalized JSD loss + """ + + # Apply temperature scaling + student_logits = student_logits / temperature + teacher_logits = teacher_logits / temperature + + # Compute log probabilities for student and probabilities for teacher + student_log_probs = F.log_softmax(student_logits, dim=-1) + teacher_log_probs = F.log_softmax(teacher_logits, dim=-1) + + if beta == 0: + jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True) + elif beta == 1: + jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True) + else: + # Compute the log of the mixture distribution + # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture + beta = torch.tensor(beta, dtype=student_log_probs.dtype) + mixture_log_probs = torch.logsumexp( + torch.stack([student_log_probs + torch.log(1 - beta), teacher_log_probs + torch.log(beta)]), + dim=0, + ) + + # Compute KL divergences using F.kl_div + # PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper. + kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True) + kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True) + + # Compute the Generalized Jensen-Shannon Divergence + jsd = beta * kl_teacher + (1 - beta) * kl_student + + # Masking + if labels is not None: + mask = labels != -100 + jsd = jsd[mask] + + # Apply reduction + if reduction == "batchmean": + return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / (jsd.size(0) * jsd.size(1)) + elif reduction == "sum": + return jsd.sum() + elif reduction == "mean": + return jsd.mean() + else: + return jsd + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + # compute student output + outputs_student = model( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + ) + + # compute teacher output in eval mode + self.teacher_model.eval() + with torch.no_grad(): + outputs_teacher = self.teacher_model( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + ) + + # slice the logits for the generated tokens using the inputs["prompts"] lengths + prompt_lengths = inputs["prompts"].shape[1] + shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :] + shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :] + shifted_labels = inputs["labels"][:, prompt_lengths:] + + # compute loss + loss = self.generalized_jsd_loss( + student_logits=shifted_student_logits, + teacher_logits=shifted_teacher_logits, + labels=shifted_labels, + beta=self.beta, + ) + + # empty cache + empty_cache() + + # Return loss + return (loss, outputs_student) if return_outputs else loss + + @staticmethod + def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None): + # Generate output with respect to the prompt only + generated_outputs = model.generate( + input_ids=inputs["prompts"], + attention_mask=inputs.get("prompt_attention_mask", None), + generation_config=generation_config, + return_dict_in_generate=True, + ) + + # Get the generated token IDs + generated_tokens = generated_outputs.sequences + # Calculate new attention mask + new_attention_mask = torch.ones_like(generated_tokens) + new_labels = generated_tokens.clone() + + # If there's pad_token_id, set attention mask to 0 for padding tokens + if pad_token_id is not None: + new_labels[new_labels == pad_token_id] = -100 + new_attention_mask[generated_tokens == pad_token_id] = 0 + + return generated_tokens, new_attention_mask, new_labels + + def training_step( + self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None + ) -> torch.Tensor: + """ + Perform a training step for the Generalized Knowledge Distillation (GKD) model. + + This method implements the on-policy learning approach described in the GKD paper. + With probability `self.lmbda`, it generates new responses using the student model, + which are then used for training instead of the original inputs. + """ + if self.seq_kd: + with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model: + new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs( + unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id + ) + inputs["input_ids"] = new_input_ids + inputs["attention_mask"] = new_attention_mask + inputs["labels"] = new_labels + if random.random() <= self.lmbda: + with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: + new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs( + unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id + ) + inputs["input_ids"] = new_input_ids + inputs["attention_mask"] = new_attention_mask + inputs["labels"] = new_labels + + loss = super().training_step(model, inputs, num_items_in_batch) + return loss + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, list[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str` or `None`, *optional*, defaults to `None`): + Name of the model. + dataset_name (`str` or `None`, *optional*, defaults to `None`): + Name of the dataset used for training. + tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. + """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + tags = tags or set() + if isinstance(tags, str): + tags = {tags} + + if hasattr(self.model.config, "unsloth_version"): + tags.add("unsloth") + + tags.update(self._tag_names) + + citation = textwrap.dedent("""\ + @inproceedings{agarwal2024on-policy, + title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}}, + author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem}, + year = 2024, + booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024}, + publisher = {OpenReview.net}, + url = {https://openreview.net/forum?id=3zKtaqxLhW}, + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), + trainer_name="GKD", + trainer_citation=citation, + paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes", + paper_id="2306.13649", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py new file mode 100644 index 0000000000000000000000000000000000000000..f3dc2326a41801677093b7391a69d511a85d756a --- /dev/null +++ b/trl/trainer/grpo_config.py @@ -0,0 +1,572 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional, Union + +import transformers +from packaging import version +from transformers import TrainingArguments + + +@dataclass +class GRPOConfig(TrainingArguments): + r""" + Configuration class for the [`GRPOTrainer`]. + + This class includes only the parameters that are specific to GRPO training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may + differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + > Parameters that control the model and reference model + + model_init_kwargs (`str`, `dict[str, Any]` or `None`, *optional*, defaults to `None`): + Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` + argument of the [`GRPOTrainer`] is provided as a string. + disable_dropout (`bool`, *optional*, defaults to `False`): + Whether to disable dropout in the model. This is useful for training with a reference model, as it + prevents the model from generating different logprobs for the same input. + + > Parameters that control the data preprocessing + + remove_unused_columns (`bool`, *optional*, defaults to `False`): + Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that + requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`. + max_prompt_length (`int` or `None`, *optional*, defaults to `512`): + Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left. + num_generations (`int` or `None`, *optional*, defaults to `8`): + Number of generations per prompt to sample. The effective batch size (num_processes * + per_device_batch_size * gradient_accumulation_steps) must be evenly divisible by this value. + max_completion_length (`int` or `None`, *optional*, defaults to `256`): + Maximum length of the generated completion. + ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): + This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, + improving generation speed. However, disabling this option allows training models that exceed the VRAM + capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible + with vLLM generation. + shuffle_dataset (`bool`, *optional*, defaults to `True`): + Whether to shuffle the training dataset. + + > Parameters that control generation + + generation_batch_size: (`int` or `None`, *optional*, defaults to `None`): + Batch size to use for generation. If `None`, it defaults to the effective training batch size: + `per_device_train_batch_size * num_processes * gradient_accumulation_steps`. + steps_per_generations: (`int` or `None`, *optional*, defaults to `None`): + Number of optimization steps per generation. If `None`, it defaults to gradient_accumulation_steps. + temperature (`float`, defaults to `1.0`): + Temperature for sampling. The higher the temperature, the more random the completions. + top_p (`float`, *optional*, defaults to `1.0`): + Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to + `1.0` to consider all tokens. + top_k (`int` or `None`, *optional*, defaults to `None`): + Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, top-k-filtering is + disabled and all tokens are considered. + min_p (`float` or `None`, *optional*, defaults to `None`): + Minimum token probability, which will be scaled by the probability of the most likely token. It must be a + value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range. + repetition_penalty (`float`, *optional*, defaults to `1.0`): + Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far. + Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat + tokens. + cache_implementation (`str` or `None`, *optional*, defaults to `None`): + Implementation of the cache method for faster generation when use_vllm is set to False. + + > Parameters that control generation acceleration powered by vLLM + + use_vllm (`bool`, *optional*, defaults to `False`): + Whether to use vLLM for generating completions. If set to `True`, the trainer will use vLLM for generation + instead of the default model.generate(). Requires `vllm` to be installed. + vllm_mode (`str`, *optional*, defaults to `"server"`): + Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `"server"` or + `"colocate"`. + + - `"server"`: The trainer will send generation requests to a separate vLLM server. Make sure a TRL vLLM + server is running (start with `trl vllm-serve`). + - `"colocate"`: vLLM will run in the same process and share the training GPUs. This avoids the need for a + separate server but may cause resource contention with training. + vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`): + Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled. + + > Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`) + vllm_server_base_url (`str` or `None`, *optional*, defaults to `None`): + Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `vllm_server_host` and + `vllm_server_port` are ignored. + vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`): + Host of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided. + vllm_server_port (`int`, *optional*, defaults to `8000`): + Port of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided. + vllm_server_timeout (`float`, *optional*, defaults to `240.0`): + Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the + timeout, a `ConnectionError` is raised. + + > Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`) + + vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.3`): + Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set to + `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when + launching the vLLM server via the `--vllm_gpu_memory_utilization` flag. + vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`): + Control the tensor parallel size for vLLM. This setting only applies when `vllm_mode` is set to + `"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when + launching the vLLM server via the `--vllm_tensor_parallel_size` flag. + + > Parameters that control the training + + beta (`float`, *optional*, defaults to `0.0`): + KL coefficient. If `0.0` (default), the reference model is not loaded, reducing memory usage and improving + training speed. + num_iterations (`int`, *optional*, defaults to `1`): + Number of iterations per batch (denoted as μ in the algorithm). + epsilon (`float`, *optional*, defaults to `0.2`): + Epsilon value for clipping. + delta: (`float` or `None`, *optional*, defaults to `None`): + Enables the upper clipping bound in two-sided GRPO loss when set to a float. If `None` (default), standard + GRPO clipping is used. Recommended to be greater than `1 + ε` when enabled. This method is introduced in + the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291). + epsilon_high (`float` or `None`, *optional*, defaults to `None`): + Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound + specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`. + reward_weights (`list[float]` or `None`, *optional*, defaults to `None`): + Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are + weighted equally with weight `1.0`. + scale_rewards (`bool`, *optional*, defaults to `True`): + Whether to scale the rewards by dividing them by their standard deviation. If `True` (default), the rewards + are normalized by the standard deviation, ensuring they have unit variance. If `False`, no scaling is + applied. The [Dr. GRPO paper](https://huggingface.co/papers/2503.20783) recommends not scaling the rewards, + as scaling by the standard deviation introduces a question-level difficulty bias. + loss_type (`str`, *optional*, defaults to `"bnpo"`): + Specifies the loss formulation to use. Supported values are: + + - `"grpo"`: Aggregates token-level losses by normalizing over sequence length. Not recommended due to + length bias—this approach tends to prefer shorter completions with positive advantages and longer ones + with negative advantages. + - `"bnpo"`: Aggregates token-level losses by normalizing number of active token in the local batch. + Note that normalization is performed over the local batch only, so results may slightly vary depending + on the local batch size, despite a constant effective batch size. When using + `per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss. + - `"dr_grpo"`: Aggregates token-level losses by normalizing with a global constant. This method was + introduced in the [Dr. GRPO paper](https://huggingface.co/papers/2503.20783) to eliminate length bias. + The value of the constant corresponds to `max_completion_length`. + mask_truncated_completions (`bool`, *optional*, defaults to `False`): + When enabled, truncated completions are excluded from the loss calculation, preventing them from being + incorrectly penalized and introducing noise during training. According to the + [DAPO](https://huggingface.co/papers/2503.14476) paper, this is a good practice for training stability. + sync_ref_model (`bool`, *optional*, defaults to `False`): + Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using + the `ref_model_mixup_alpha` parameter. This synchronization originates from the + [TR-DPO](https://huggingface.co/papers/2404.09656) paper. + ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`): + α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix + between the current policy and the previous reference policy during updates. The reference policy is + updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you + must set `sync_ref_model=True`. + ref_model_sync_steps (`int`, *optional*, defaults to `512`): + τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how + frequently the current policy is synchronized with the reference policy. To use this parameter, you must + set `sync_ref_model=True`. + use_liger_loss (`bool`, *optional*, defaults to `False`): + Whether to use the Liger GRPO loss. + + > Parameters that control the logging + + log_completions (`bool`, *optional*, defaults to `False`): + Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is + installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`. + num_completions_to_print (`int` or `None`, *optional*, defaults to `None`): + Number of completions to print with `rich`. If `None`, all completions are logged. + wandb_log_unique_prompts (`bool`, *optional*, defaults to `False`): + Whether to log unique prompts in wandb. If `True`, only unique prompts are logged. If `False`, all + prompts are logged. + """ + + if version.parse(transformers.__version__) >= version.parse("4.51.0"): + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=1e-6, + metadata={"help": "The initial learning rate for AdamW."}, + ) + logging_steps: float = field( + default=10, + metadata={ + "help": ( + "Log every X updates steps. Should be an integer or a float in range `[0,1)`. " + "If smaller than 1, will be interpreted as ratio of total training steps." + ) + }, + ) + bf16: bool = field( + default=True, + metadata={ + "help": ( + "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change." + ) + }, + ) + + # Parameters that control the model and reference model + model_init_kwargs: Optional[Union[dict, str]] = field( + default=None, + metadata={ + "help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` " + "argument of the `GRPOTrainer` is provided as a string." + }, + ) + disable_dropout: bool = field( + default=False, + metadata={ + "help": "Whether to disable dropout in the model. This is useful for training with a reference model, as " + "it prevents the model from generating different logprobs for the same input." + }, + ) + + # Parameters that control the data preprocessing + # The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on + # additional columns to compute the reward + remove_unused_columns: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function " + "that requires any column other than 'prompts' and 'completions', you should keep this to `False`." + }, + ) + max_prompt_length: Optional[int] = field( + default=512, + metadata={ + "help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left." + }, + ) + num_generations: Optional[int] = field( + default=8, + metadata={ + "help": "Number of generations to sample. The effective batch size (num_processes * per_device_batch_size " + "* gradient_accumulation_steps) must be evenly divisible by this value." + }, + ) + max_completion_length: Optional[int] = field( + default=256, + metadata={"help": "Maximum length of the generated completion."}, + ) + ds3_gather_for_generation: bool = field( + default=True, + metadata={ + "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for " + "generation, improving generation speed. However, disabling this option allows training models that " + "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option " + "is not compatible with vLLM generation." + }, + ) + shuffle_dataset: Optional[bool] = field( + default=True, + metadata={"help": "Whether to shuffle the training dataset."}, + ) + + # Parameters that control generation + generation_batch_size: Optional[int] = field( + default=None, + metadata={ + "help": "Batch size to use for generation. If `None`, it defaults to the effective training batch size: " + "`per_device_train_batch_size * num_processes * gradient_accumulation_steps`." + }, + ) + steps_per_generation: Optional[int] = field( + default=None, + metadata={ + "help": "Number of optimization steps per generation. If `None`, it defaults to gradient_accumulation_steps." + }, + ) + temperature: float = field( + default=1.0, + metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."}, + ) + top_p: float = field( + default=1.0, + metadata={ + "help": "Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. " + "Set to 1.0 to consider all tokens." + }, + ) + top_k: Optional[int] = field( + default=None, + metadata={ + "help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, " + "top-k-filtering is disabled and all tokens are considered." + }, + ) + min_p: Optional[float] = field( + default=None, + metadata={ + "help": "Minimum token probability, which will be scaled by the probability of the most likely token. It " + "must be a value between 0.0 and 1.0. Typical values are in the 0.01-0.2 range." + }, + ) + repetition_penalty: float = field( + default=1.0, + metadata={ + "help": "Float that penalizes new tokens based on whether they appear in the prompt and the generated " + "text so far. Values > 1.0 encourage the model to use new tokens, while values < 1.0 encourage the model " + "to repeat tokens." + }, + ) + cache_implementation: Optional[str] = field( + default=None, + metadata={"help": "Implementation of the cache method for faster generation when use_vllm is set to False."}, + ) + + # Parameters that control generation acceleration powered by vLLM + use_vllm: bool = field( + default=False, + metadata={ + "help": "Whether to use vLLM for generating completions. If set to `True`, the trainer will use vLLM for " + "generation instead of the default model.generate(). Requires `vllm` to be installed." + }, + ) + vllm_server_base_url: Optional[str] = field( + default=None, + metadata={ + "help": "Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, `vllm_server_host` " + "and `vllm_server_port` are ignored." + }, + ) + vllm_mode: str = field( + default="server", + metadata={ + "help": "Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `server` or " + "`'colocate'`. `'server'`: The trainer will send generation requests to a separate vLLM server. Make sure a " + "TRL vLLM server is running (start with `trl vllm-serve`). `'colocate'`: vLLM will run in the same " + "process and share the training GPUs. This avoids the need for a separate server but may cause resource " + "contention with training." + }, + ) + vllm_guided_decoding_regex: Optional[str] = field( + default=None, + metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."}, + ) + + # Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`) + vllm_server_host: str = field( + default="0.0.0.0", + metadata={"help": "Host of the vLLM server to connect to. Ignored if vllm_server_base_url is provided."}, + ) + vllm_server_port: int = field( + default=8000, + metadata={"help": "Port of the vLLM server to connect to. Ignored if vllm_server_base_url is provided."}, + ) + vllm_server_timeout: float = field( + default=240.0, + metadata={ + "help": "Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up " + "after the timeout, a `ConnectionError` is raised." + }, + ) + + # Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`) + vllm_gpu_memory_utilization: float = field( + default=0.3, + metadata={ + "help": "Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set " + "to `'colocate'`. If you are using `vllm_mode='server'`, this parameter must be passed separately when " + "launching the vLLM server via the `--vllm_gpu_memory_utilization` flag." + }, + ) + vllm_tensor_parallel_size: int = field( + default=1, + metadata={ + "help": "Control the tensor parallel size for vLLM. This setting only applies when `vllm_mode` is set " + "to `'colocate'`. If you are using `vllm_mode='server'`, this parameter must be passed separately when " + "launching the vLLM server via the `--vllm_tensor_parallel_size` flag." + }, + ) + + # Parameters that control the training + beta: float = field( + default=0.0, + metadata={ + "help": "KL coefficient. If `0.0` (default), the reference model is not loaded, reducing memory usage and " + "improving training speed." + }, + ) + num_iterations: int = field( + default=1, + metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."}, + ) + epsilon: float = field( + default=0.2, + metadata={"help": "Epsilon value for clipping."}, + ) + delta: Optional[float] = field( + default=None, + metadata={ + "help": "Enables the upper clipping bound in two-sided GRPO loss when set to a float. If `None` " + "(default), standard GRPO clipping is used. Recommended to be greater than `1 + ε` when enabled. This " + "method is introduced in the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291)." + }, + ) + epsilon_high: Optional[float] = field( + default=None, + metadata={ + "help": "Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the " + "lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`." + }, + ) + reward_weights: Optional[list[float]] = field( + default=None, + metadata={ + "help": "Weights for each reward function. Must match the number of reward functions. If `None`, all " + "rewards are weighted equally with weight `1.0`." + }, + ) + scale_rewards: bool = field( + default=True, + metadata={ + "help": "Whether to scale the rewards by dividing them by their standard deviation. If `True` (default), " + "the rewards are normalized by the standard deviation, ensuring they have unit variance. If `False`, no " + "scaling is applied. The Dr. GRPO paper recommends not scaling the rewards, as scaling by the standard " + "deviation introduces a question-level difficulty bias." + }, + ) + loss_type: str = field( + default="bnpo", + metadata={ + "help": "Specifies the loss formulation to use. Supported values are `grpo`, `bnpo`, and `dr_grpo`. " + "`'grpo'`: Aggregates token-level losses by normalizing over sequence length. Not recommended due to " + "length bias—this approach tends to prefer shorter completions with positive advantages and longer ones " + "with negative advantages. " + "`'bnpo'`: Aggregates token-level losses by normalizing number of active token in the local batch. " + "Note that normalization is performed over the local batch only, so results may slightly vary depending " + "on the local batch size, despite a constant effective batch size. When using " + "`per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss. " + "`'dr_grpo'`: Aggregates token-level losses by normalizing with a global constant. This method was " + "introduced in the Dr. GRPO paper to eliminate length bias. The value of the constant corresponds to " + "`max_completion_length`." + }, + ) + mask_truncated_completions: bool = field( + default=False, + metadata={ + "help": "When enabled, truncated completions are excluded from the loss calculation, preventing them from " + "being incorrectly penalized and introducing noise during training. According to the DAPO paper, this is " + "a good practice for training stability." + }, + ) + sync_ref_model: bool = field( + default=False, + metadata={ + "help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` " + "steps, using the `ref_model_mixup_alpha` parameter." + }, + ) + ref_model_mixup_alpha: float = field( + default=0.6, + metadata={ + "help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the " + "previous reference policy during updates. The reference policy is updated according to the equation: " + "`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`." + }, + ) + ref_model_sync_steps: int = field( + default=512, + metadata={ + "help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is " + "synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`." + }, + ) + use_liger_loss: bool = field( + default=False, + metadata={"help": "Whether to use the Liger GRPO loss."}, + ) + + # Parameters that control the logging + log_completions: bool = field( + default=False, + metadata={ + "help": "Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is " + "installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`." + }, + ) + num_completions_to_print: Optional[int] = field( + default=None, + metadata={"help": "Number of completions to print with `rich`. If `None`, all completions are logged."}, + ) + wandb_log_unique_prompts: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether to log unique prompts in wandb. If `True`, only unique prompts are logged. If `False`, " + "all prompts are logged." + }, + ) + + def __post_init__(self): + super().__post_init__() + + num_processes = self.world_size + # The current default effective batch size + if self.generation_batch_size is not None and self.steps_per_generation is not None: + raise ValueError( + "'generation_batch_size' and 'steps_per_generation' can not be both configured at the same time" + ) + + if self.steps_per_generation is None: + self.steps_per_generation = self.gradient_accumulation_steps + + if self.generation_batch_size is None: + self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation + + if self.generation_batch_size % self.per_device_train_batch_size * num_processes != 0: + raise ValueError( + f"generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch size " + f"({self.per_device_train_batch_size * num_processes})." + ) + + self.steps_per_generation = self.generation_batch_size // (self.per_device_train_batch_size * num_processes) + + # Check if the effective batch size can be divided by the number of generations + if self.num_generations < 2: + raise ValueError( + "GRPO requires at least 2 generations per prompt to calculate the advantages. You provided " + f"{self.num_generations}, which is less than the minimum required." + ) + possible_values = [ + n_gen for n_gen in range(2, self.generation_batch_size + 1) if (self.generation_batch_size) % n_gen == 0 + ] + + if self.num_generations not in possible_values: + raise ValueError( + f"The effective train batch size ({num_processes} x {self.per_device_train_batch_size} x " + f"{self.steps_per_generation}) must be evenly divisible by the number of generations per " + f"prompt ({self.num_generations}). Given the current effective train batch size, the valid values for " + f"the number of generations are: {possible_values}." + ) + if self.eval_strategy != "no": + global_eval_batch_size = self.per_device_eval_batch_size * num_processes + possible_values = [ + n_gen for n_gen in range(2, global_eval_batch_size + 1) if (global_eval_batch_size) % n_gen == 0 + ] + if self.num_generations not in possible_values: + raise ValueError( + f"The global eval batch size ({num_processes} x {self.per_device_eval_batch_size}) must be " + f"evenly divisible by the number of generations per prompt ({self.num_generations}). Given the " + "current global eval batch size, the valid values for the number of generations are: " + f"{possible_values}." + ) + if self.delta is not None and self.use_liger_loss: + raise ValueError("Liger loss does not support two-sided GRPO loss yet.") diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..4b5138eefcb0099a61d5c34bf947539c57e3ab72 --- /dev/null +++ b/trl/trainer/grpo_trainer.py @@ -0,0 +1,1529 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import textwrap +import warnings +from collections import defaultdict, deque +from collections.abc import Sized +from contextlib import nullcontext +from functools import partial +from pathlib import Path +from typing import Any, Callable, Optional, Union + +import datasets +import torch +import torch.utils.data +import transformers +from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed +from datasets import Dataset, IterableDataset +from packaging import version +from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.utils.data import DataLoader, Sampler +from transformers import ( + AutoModelForCausalLM, + AutoModelForSequenceClassification, + AutoTokenizer, + GenerationConfig, + PreTrainedModel, + PreTrainedTokenizerBase, + Trainer, + TrainerCallback, + is_wandb_available, +) +from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled +from transformers.trainer_utils import seed_worker +from transformers.utils import is_datasets_available, is_peft_available, is_rich_available + +from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template +from ..extras.profiling import profiling_context, profiling_decorator +from ..extras.vllm_client import VLLMClient +from ..import_utils import is_liger_kernel_available, is_vllm_available +from ..models import create_reference_model, prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation +from ..models.utils import _ForwardRedirection +from .callbacks import SyncRefModelCallback +from .grpo_config import GRPOConfig +from .utils import ( + disable_dropout_in_model, + generate_model_card, + get_comet_experiment_url, + pad, + print_prompt_completions_sample, + selective_log_softmax, +) + + +if is_peft_available(): + from peft import PeftConfig, get_peft_model + +if is_liger_kernel_available(): + from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss + +if is_vllm_available(): + from vllm import LLM, SamplingParams + from vllm.sampling_params import GuidedDecodingParams + +if is_wandb_available(): + import wandb + +# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of +# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model. +RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]] + + +class RepeatSampler(Sampler): + """ + Sampler that repeats the indices of a dataset in a structured manner. + + Args: + data_source (`Sized`): + Dataset to sample from. + mini_repeat_count (`int`): + Number of times to repeat each index per batch. + batch_size (`int`, *optional*, defaults to `1`): + Number of unique indices per batch. + repeat_count (`int`, *optional*, defaults to `1`): + Number of times to repeat the full sampling process. + shuffle (`bool`, *optional*, defaults to `True`): + Whether to shuffle the dataset. + seed (`int` or `None`, *optional*, defaults to `None`): + Random seed for reproducibility (only affects this sampler). + + Example: + ```python + >>> sampler = RepeatRandomSampler(["a", "b", "c", "d", "e", "f", "g"], mini_repeat_count=2, batch_size=3, repeat_count=4) + >>> list(sampler) + [4, 4, 3, 3, 0, 0, + 4, 4, 3, 3, 0, 0, + 4, 4, 3, 3, 0, 0, + 4, 4, 3, 3, 0, 0, + + 1, 1, 2, 2, 6, 6, + 1, 1, 2, 2, 6, 6, + 1, 1, 2, 2, 6, 6, + 1, 1, 2, 2, 6, 6] + ``` + + ```txt + mini_repeat_count = 3 + - - - + [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, | + 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, | + 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, | + repeat_count = 2 + 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, | + 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, | + 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, ...] | + --------- --------- --------- --------- + --------- --------- --------- --------- + --------- --------- --------- --------- + batch_size = 12 + ``` + """ + + def __init__( + self, + data_source: Sized, + mini_repeat_count: int, + batch_size: int = 1, + repeat_count: int = 1, + shuffle: bool = True, + seed: Optional[int] = None, + ): + self.data_source = data_source + self.mini_repeat_count = mini_repeat_count + self.batch_size = batch_size + self.repeat_count = repeat_count + self.num_samples = len(data_source) + self.shuffle = shuffle + self.seed = seed + + if shuffle: + self.generator = torch.Generator() # Create a local random generator + if seed is not None: + self.generator.manual_seed(seed) + + def __iter__(self): + if self.shuffle: + # E.g., [2, 4, 3, 1, 0, 6, 5] (num_samples = 7) + indexes = torch.randperm(self.num_samples, generator=self.generator).tolist() + else: + indexes = list(range(self.num_samples)) + + # [2, 4, 3, 1, 0, 6, 5] + # -> [[2, 4, 3], [1, 0, 6], [5]] (batch_size = 3) + indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)] + + # [[2, 4, 3], [1, 0, 6], [5]] + # -> [[2, 4, 3], [1, 0, 6]] + indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size] + + for chunk in indexes: + for _ in range(self.repeat_count): + for index in chunk: + for _ in range(self.mini_repeat_count): + yield index + + def __len__(self) -> int: + return self.num_samples * self.mini_repeat_count * self.repeat_count + + +# torch.nanstd doesn't exist, so we define it here +def nanstd(tensor: torch.Tensor) -> torch.Tensor: + """ + Compute the standard deviation of a tensor, ignoring NaNs. This function only supports 1D tensors. + + Args: + tensor (`torch.Tensor`): + Input tensor of shape `(N,)`. + + Returns: + `torch.Tensor`: + Standard deviation of the tensor, ignoring NaNs. + """ + variance = torch.nanmean((tensor - torch.nanmean(tensor, keepdim=True)) ** 2) # Compute variance ignoring NaNs + count = torch.sum(~torch.isnan(tensor)) # Count of non-NaN values + variance *= count / (count - 1) # Bessel's correction + return torch.sqrt(variance) + + +def split_tensor_dict( + tensor_dict: dict[str, Optional[torch.Tensor]], num_chunks: int +) -> list[dict[str, Optional[torch.Tensor]]]: + """ + Splits a dictionary of tensors along the first dimension into `num_chunks` equal parts. + + Example: + >>> x = torch.arange(12).reshape(6, 2) + >>> y = torch.arange(6).reshape(6, 1) + >>> tensor_dict = {"x": x, "y": y} + >>> split_tensor_dict(tensor_dict, 3) + [ + {"x": tensor([[0, 1], [2, 3]]), "y": tensor([[0], [1]])}, + {"x": tensor([[4, 5], [6, 7]]), "y": tensor([[2], [3]])}, + {"x": tensor([[ 8, 9], [10, 11]]), "y": tensor([[4], [5]])} + ] + """ + first_tensor = next(tensor for tensor in tensor_dict.values() if tensor is not None) + chunk_size = first_tensor.shape[0] // num_chunks + return [ + { + key: tensor[i * chunk_size : (i + 1) * chunk_size] if tensor is not None else None + for key, tensor in tensor_dict.items() + } + for i in range(num_chunks) + ] + + +def shuffle_tensor_dict(tensor_dict: dict[str, Optional[torch.Tensor]]) -> dict[str, Optional[torch.Tensor]]: + """ + Shuffles a dictionary of tensors along the first dimension in unison. + + Example: + >>> x = torch.arange(6).reshape(3, 2) + >>> y = torch.arange(3).reshape(3, 1) + >>> tensor_dict = {"x": x, "y": y} + >>> shuffle_tensor_dict(tensor_dict) + {'x': tensor([[2, 3], + [0, 1], + [4, 5]]), + 'y': tensor([[1], + [0], + [2]])} + """ + first_tensor = next(tensor for tensor in tensor_dict.values() if tensor is not None) + batch_size = first_tensor.shape[0] + permutation = torch.randperm(batch_size) + return {key: tensor[permutation] if tensor is not None else None for key, tensor in tensor_dict.items()} + + +def nanmin(tensor: torch.Tensor) -> torch.Tensor: + """ + Compute the minimum value of a tensor, ignoring NaNs. This function only supports 1D tensors. + + Args: + tensor (`torch.Tensor`): Input tensor of shape `(N,)`. + + Returns: + `torch.Tensor`: Minimum value of the tensor, ignoring NaNs. Returns NaN if all values are NaN. + """ + if torch.isnan(tensor).all(): + return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device) + return torch.min(tensor[~torch.isnan(tensor)]) + + +def nanmax(tensor: torch.Tensor) -> torch.Tensor: + """ + Compute the maximum value of a tensor, ignoring NaNs. This function only supports 1D tensors. + + Args: + tensor (`torch.Tensor`): Input tensor of shape `(N,)`. + + Returns: + `torch.Tensor`: Maximum value of the tensor, ignoring NaNs. Returns NaN if all values are NaN. + """ + if torch.isnan(tensor).all(): + return torch.tensor(float("nan"), dtype=tensor.dtype, device=tensor.device) + return torch.max(tensor[~torch.isnan(tensor)]) + + +def identity(x): + """Do we really need docs for this?""" + return x + + +class GRPOTrainer(Trainer): + """ + Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the + paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300). + + Example: + + ```python + from datasets import load_dataset + from trl import GRPOTrainer + + dataset = load_dataset("trl-lib/tldr", split="train") + + def reward_func(completions, **kwargs): + # Dummy reward function that rewards completions with more unique letters. + return [float(len(set(completion))) for completion in completions] + + trainer = GRPOTrainer( + model="Qwen/Qwen2-0.5B-Instruct", + reward_funcs=reward_func, + train_dataset=dataset, + ) + + trainer.train() + ``` + + Args: + model (`Union[str, PreTrainedModel]`): + Model to be trained. Can be either: + + - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or + a path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is + loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments + in `args.model_init_kwargs`. + - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. + reward_funcs (`Union[RewardFunc, list[RewardFunc]]`): + Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward + functions with the prompts and completions and sum the rewards. Can be either: + + - A single reward function, such as: + - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a + path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded + using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the + keyword arguments in `args.model_init_kwargs`. + - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported. + - A custom reward function: The function is provided with the prompts and the generated completions, + plus any additional columns in the dataset. It should return a list of rewards. Custom reward + functions can also return None when the reward is not applicable to those samples. This is useful for + multi-task training where different reward functions apply to different types of samples. When a + reward function returns None for a sample, that reward function is excluded from the reward + calculation for that sample. For more details, see + [Using a custom reward function](#using-a-custom-reward-function). + - A list of reward functions, where each item can independently be any of the above types. Mixing different + types within the list (e.g., a string model ID and a custom reward function) is allowed. + args ([`GRPOConfig`], *optional*, defaults to `None`): + Configuration for this trainer. If `None`, a default configuration is used. + train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): + Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is + ignored. The format of the samples can be either: + + - [Standard](dataset_formats#standard): Each sample contains plain text. + - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role + and content). + eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): + Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. + processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`): + Processing class used to process the data. The padding side must be set to "left". If `None`, the + processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`]. A + padding token, `processing_class.pad_token`, must be set. If the processing class has not set a padding + token, `processing_class.eos_token` will be used as the default. + reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`): + Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either: + + - A single processing class: Used when `reward_funcs` contains only one reward function. + - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`. + If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is + `None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`]. + For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]), + the corresponding entries in `reward_processing_classes` are ignored. + callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`): + List of callbacks to customize the training loop. Will add those to the list of default callbacks + detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback). + + If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] + method. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`): + A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your + model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. + peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`): + PEFT configuration used to wrap the model. If `None`, the model is not wrapped. + """ + + _tag_names = ["trl", "grpo"] + + def __init__( + self, + model: Union[str, PreTrainedModel], + reward_funcs: Union[RewardFunc, list[RewardFunc]], + args: Optional[GRPOConfig] = None, + train_dataset: Optional[Union[Dataset, IterableDataset]] = None, + eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None, + processing_class: Optional[PreTrainedTokenizerBase] = None, + reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None, + callbacks: Optional[list[TrainerCallback]] = None, + optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), + peft_config: Optional["PeftConfig"] = None, + ): + # Args + if args is None: + model_name = model if isinstance(model, str) else model.config._name_or_path + model_name = model_name.split("/")[-1] + args = GRPOConfig(f"{model_name}-GRPO") + + # Models + # Trained model + model_init_kwargs = args.model_init_kwargs or {} + if isinstance(model, str): + model_id = model + torch_dtype = model_init_kwargs.get("torch_dtype") + if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None: + pass # torch_dtype is already a torch.dtype or "auto" or None + elif isinstance(torch_dtype, str): # it's a str, but not "auto" + torch_dtype = getattr(torch, torch_dtype) + model_init_kwargs["torch_dtype"] = torch_dtype + else: + raise ValueError( + "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing " + f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}." + ) + # Disable caching if gradient checkpointing is enabled (not supported) + model_init_kwargs["use_cache"] = ( + False if args.gradient_checkpointing else model_init_kwargs.get("use_cache") + ) + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + else: + model_id = model.config._name_or_path + if args.model_init_kwargs is not None: + raise ValueError( + "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. " + "This argument can only be used when the `model` argument is a string." + ) + + if peft_config is not None: + if not is_peft_available(): + raise ImportError("PEFT is required to use `peft_config`. Run `pip install peft`.") + model = get_peft_model(model, peft_config) + + # Enable gradient checkpointing if requested + if args.gradient_checkpointing: + model = self._enable_gradient_checkpointing(model, args) + + # Processing class + if processing_class is None: + processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left") + if processing_class.pad_token is None: + processing_class.pad_token = processing_class.eos_token + + # Reward functions + if not isinstance(reward_funcs, list): + reward_funcs = [reward_funcs] + self.reward_func_names = [] + for i, reward_func in enumerate(reward_funcs): + if isinstance(reward_func, str): + reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained( + reward_func, num_labels=1, **model_init_kwargs + ) + if isinstance(reward_funcs[i], nn.Module): # Use Module over PretrainedModel for compat w/ compiled models + self.reward_func_names.append(reward_funcs[i].config._name_or_path.split("/")[-1]) + else: + self.reward_func_names.append(reward_funcs[i].__name__) + self.reward_funcs = reward_funcs + + # Reward weights + if args.reward_weights is not None: + if len(args.reward_weights) != len(reward_funcs): + raise ValueError( + f"Number of reward weights ({len(args.reward_weights)}) must match number of reward " + f"functions ({len(reward_funcs)})" + ) + self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) + else: + self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32) + + # Reward processing class + if reward_processing_classes is None: + reward_processing_classes = [None] * len(reward_funcs) + elif not isinstance(reward_processing_classes, list): + reward_processing_classes = [reward_processing_classes] + else: + if len(reward_processing_classes) != len(reward_funcs): + raise ValueError("The number of reward processing classes must match the number of reward functions.") + + for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)): + if isinstance(reward_func, PreTrainedModel): + if reward_processing_class is None: + reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path) + if reward_processing_class.pad_token_id is None: + reward_processing_class.pad_token = reward_processing_class.eos_token + # The reward model computes the reward for the latest non-padded token in the input sequence. + # So it's important to set the pad token ID to the padding token ID of the processing class. + reward_func.config.pad_token_id = reward_processing_class.pad_token_id + reward_processing_classes[i] = reward_processing_class + self.reward_processing_classes = reward_processing_classes + + # Training arguments + self.max_prompt_length = args.max_prompt_length + self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper + self.num_generations = args.num_generations # = G in the GRPO paper + self.temperature = args.temperature + self.top_p = args.top_p + self.top_k = args.top_k + self.min_p = args.min_p + self.repetition_penalty = args.repetition_penalty + self.use_vllm = args.use_vllm + self.vllm_mode = args.vllm_mode + self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode + self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode + self.use_liger_loss = args.use_liger_loss + self.loss_type = args.loss_type + self.scale_rewards = args.scale_rewards + self.mask_truncated_completions = args.mask_truncated_completions + + # Datasets + self.shuffle_dataset = args.shuffle_dataset + + if ( + isinstance(train_dataset, IterableDataset) + or isinstance(eval_dataset, IterableDataset) + or ( + isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values()) + ) + ): + # See https://github.com/huggingface/trl/issues/3213 + raise NotImplementedError( + "Iterable datasets are not yet supported in GRPOTrainer. Please use a standard dataset instead." + ) + + # Multi-step + self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper + self.epsilon_low = args.epsilon + self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon + # Tracks the number of iterations (forward + backward passes), including those within a grad accum cycle + self._step = 0 + # Buffer the batch to reuse generated outputs across multiple updates. For more details, see + # `_get_train_sampler` and `_prepare_inputs`. + self._buffered_inputs = None + + # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the + # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the + # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning: + # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To + # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True. + # This acts as a flag to indicate that the warning has already been issued. + model.warnings_issued["estimate_tokens"] = True + + super().__init__( + model=model, + args=args, + data_collator=identity, # No data collation is needed in GRPO + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + callbacks=callbacks, + optimizers=optimizers, + ) + + # Reference model + self.beta = args.beta + if self.beta == 0.0: + # If beta is 0.0, the reference model is not needed + self.ref_model = None + elif is_deepspeed_zero3_enabled() or self.is_fsdp_enabled: + self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs) + elif is_peft_model(model): + # If PEFT is used, the reference model is not needed since the adapter can be disabled + # to revert to the initial model. + self.ref_model = None + else: + # If PEFT configuration is not provided, create a reference model based on the initial model. + self.ref_model = create_reference_model(model) + + # Disable dropout in the models + if args.disable_dropout: + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + # Liger loss + if self.use_liger_loss: + if not is_liger_kernel_available(): + raise ImportError( + "Liger is required to use `liger_loss` as the GRPO loss. Run `pip install liger-kernel`." + ) + # redirect the model.module forward to the model forward to ensure pre-forward hooks are called + self._forward_redirection = _ForwardRedirection() + + self.liger_grpo_loss = LigerFusedLinearGRPOLoss( + beta=self.beta, + epsilon_low=self.epsilon_low, + epsilon_high=self.epsilon_high, + temperature=self.temperature, + use_ref_model=self.beta != 0.0, + loss_type=self.loss_type, + max_completion_length=self.max_completion_length, + ) + + # Initialize the metrics + self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + self._total_train_tokens = 0 + self.log_completions = args.log_completions + self.wandb_log_unique_prompts = args.wandb_log_unique_prompts + self.num_completions_to_print = args.num_completions_to_print + # maxlen is set to the total number of forward passes per step. This value of `maxlen` ensures we log only the + # final optimization step. + maxlen = self.accelerator.num_processes * args.per_device_train_batch_size * args.steps_per_generation + self._textual_logs = { + "prompt": deque(maxlen=maxlen), + "completion": deque(maxlen=maxlen), + "rewards": defaultdict(lambda: deque(maxlen=maxlen)), + "advantages": deque(maxlen=maxlen), + } + + # Ensure each process receives a unique seed to prevent duplicate completions when generating with + # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but + # it's safer to set it in all cases. + set_seed(args.seed, device_specific=True) + + if self.use_vllm: + if not is_vllm_available(): + raise ImportError( + "vLLM is not available and `use_vllm` is set to True. Please install vLLM with " + "`pip install vllm` to use it." + ) + + if self.vllm_mode == "server" and self.accelerator.is_main_process: + if args.vllm_server_base_url is not None: + base_url = args.vllm_server_base_url + else: + base_url = f"http://{args.vllm_server_host}:{args.vllm_server_port}" + self.vllm_client = VLLMClient(base_url=base_url, connection_timeout=args.vllm_server_timeout) + self.vllm_client.init_communicator() + + elif self.vllm_mode == "colocate": + # Make sure vllm_tensor_parallel_size group size evenly divides the world size - each group should have + # the same number of ranks + if not self.accelerator.num_processes % self.vllm_tensor_parallel_size == 0: + raise ValueError( + f"vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size " + f"({self.accelerator.num_processes}) evenly." + ) + + if self.vllm_tensor_parallel_size > 1: + # Create subgroups of ranks for TP, each group with `vllm_tensor_parallel_size` ranks. + # For example, if world_size=8 and vllm_tensor_parallel_size=2 → groups: [0,1], [2,3], [4,5], [6,7] + self.tp_group, _ = torch.distributed.new_subgroups_by_enumeration( + [ + list(range(i * self.vllm_tensor_parallel_size, (i + 1) * self.vllm_tensor_parallel_size)) + for i in range(self.accelerator.num_processes // self.vllm_tensor_parallel_size) + ] + ) + + self.llm = LLM( + model=model.name_or_path, + tensor_parallel_size=args.vllm_tensor_parallel_size, + gpu_memory_utilization=self.vllm_gpu_memory_utilization, + max_num_seqs=self.args.per_device_train_batch_size + * self.vllm_tensor_parallel_size + * self.args.gradient_accumulation_steps, + max_model_len=self.max_prompt_length + self.max_completion_length, + distributed_executor_backend="external_launcher", + # Feed identical seed for tp groups to ensure sampling results are the same across workers + seed=self.accelerator.process_index // self.vllm_tensor_parallel_size, + # Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768) - thinking there's not enough memory + max_num_batched_tokens=4096, + ) + + # vLLM specific sampling arguments + self.guided_decoding_regex = args.vllm_guided_decoding_regex + + self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation + + # When using vLLM, the main process is responsible for loading the model weights. This can cause process + # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we + # synchronize all processes after vLLM has been fully initialized. + self.accelerator.wait_for_everyone() + else: + self.generation_config = GenerationConfig( + max_new_tokens=self.max_completion_length, + do_sample=True, + pad_token_id=processing_class.pad_token_id, + bos_token_id=processing_class.bos_token_id, + eos_token_id=processing_class.eos_token_id, + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + min_p=self.min_p, + repetition_penalty=self.repetition_penalty, + cache_implementation=args.cache_implementation, + ) + + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. + self.model_accepts_loss_kwargs = False + + # Add tags to the model + self.model.add_model_tags(self._tag_names) + + if self.ref_model is not None: + if self.is_deepspeed_enabled: + self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) + elif self.is_fsdp_enabled: + self.ref_model = prepare_fsdp(self.ref_model, self.accelerator) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + if args.sync_ref_model: + self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator)) + + for i, reward_func in enumerate(self.reward_funcs): + if isinstance(reward_func, PreTrainedModel): + if self.is_deepspeed_enabled: + self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator) + else: + # set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp + self.reward_funcs[i] = self.accelerator.prepare_model( + reward_func, evaluation_mode=True, device_placement=True + ) + + def _set_signature_columns_if_needed(self): + # If `self.args.remove_unused_columns` is True, non-signature columns are removed. + # By default, this method sets `self._signature_columns` to the model's expected inputs. + # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work. + # Instead, we set them to the columns expected by the `training_step` method, hence the override. + if self._signature_columns is None: + self._signature_columns = ["prompt"] + + # This method overrides `Trainer.get_train_dataloader` to support our custom batching strategy. + # Instead of returning a standard per-step batch (i.e., `per_device_batch_size), our dataloader loads an + # *generation* batch (i.e., `per_device_batch_size × steps_per_generation`). This allows us to generate completions + # once every steps_per_generation step—rather than once per accumulation step—which is significantly more + # efficient. The only change from the original implementation is multiplying the batch size by + # `steps_per_generation`. Thus, `_prepare_inputs` is called with this *generation* batch, and it handles the + # splitting internally. + # Maintenance note: This method is a copy-paste of the original `Trainer.get_train_dataloader` with only one line + # modification. As a result, some parts of the method aren't relevant to GRPO, but we keep them to stay one line + # apart from the super method, ensuring easier maintenance in the future. + def get_train_dataloader(self): + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + train_dataset = self._remove_unused_columns(train_dataset, description="training") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="training") + + dataloader_params = { + "batch_size": self._train_batch_size * self.args.steps_per_generation, # < this is the change + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_train_sampler() + dataloader_params["drop_last"] = self.args.dataloader_drop_last + if version.parse(transformers.__version__) >= version.parse("4.52.0"): + # from transformers 4.52.0, the `seed_worker` requires the `num_workers` and `rank` arguments + dataloader_params["worker_init_fn"] = partial( + seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index + ) + else: + dataloader_params["worker_init_fn"] = seed_worker + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + + def _get_train_sampler(self, dataset: Optional[Dataset] = None) -> Sampler: + # Returns a sampler that + # 1. ensures each prompt is repeated across multiple processes. This guarantees that identical prompts are + # distributed to different GPUs, allowing rewards to be computed and normalized correctly within each prompt + # group. Using the same seed across processes ensures consistent prompt assignment, preventing discrepancies + # in group formation. + # 2. repeats the batch multiple times to allow reusing generations across multiple updates. Refer to + # _prepare_inputs to see how the generations are stored and reused. + + # In the following figure, the values are the prompt indices. The first row shows the first sampled batch, the + # second row shows the second sampled batch, and so on. + # + # | GPU 0 | GPU 1 | + # + # global_step step <-───> num_generations=2 + # <-───────> per_device_train_batch_size=3 + # grad_accum ▲ ▲ 0 0 0 0 1 1 2 2 <- Generate for the first `steps_per_generation` (prompts 0 to 11); store the completions; use the first slice to compute the loss + # =2 ▼ | 0 1 3 3 4 4 5 5 <- Take the stored generations and use the second slice to compute the loss + # | + # | 1 2 6 6 7 7 8 8 <- Take the stored generations and use the third slice to compute the loss + # steps_per_gen=4 ▼ 1 3 9 9 10 10 11 11 <- Take the stored generations and use the fourth slice to compute the loss + # + # 2 4 12 12 13 13 14 14 <- Generate for the second `steps_per_generation` (prompts 12 to 23); store the completions; use the first slice to compute the loss + # 2 5 15 15 16 16 17 17 <- Take the stored generations and use the second slice to compute the loss + # ... + if dataset is None: + dataset = self.train_dataset + return RepeatSampler( + data_source=dataset, + mini_repeat_count=self.num_generations, + batch_size=self.args.generation_batch_size // self.num_generations, + repeat_count=self.num_iterations * self.args.steps_per_generation, + shuffle=self.shuffle_dataset, + seed=self.args.seed, + ) + + def _get_eval_sampler(self, eval_dataset) -> Sampler: + # See _get_train_sampler for an explanation of the sampler. + return RepeatSampler( + data_source=eval_dataset, + mini_repeat_count=self.num_generations, + seed=self.args.seed, + ) + + def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel: + """Enables gradient checkpointing for the model.""" + # Ensure use_cache is disabled + model.config.use_cache = False + + # Enable gradient checkpointing on the base model for PEFT + if is_peft_model(model): + model.base_model.gradient_checkpointing_enable() + # Enable gradient checkpointing for non-PEFT models + else: + model.gradient_checkpointing_enable() + + gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} + use_reentrant = ( + "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"] + ) + + if use_reentrant: + model.enable_input_require_grads() + + return model + + @profiling_decorator + def _get_last_hidden_state(self, unwrapped_model, input_ids, attention_mask, logits_to_keep=None): + if is_peft_model(unwrapped_model): + unwrapped_model = unwrapped_model.base_model.model + last_hidden_state = unwrapped_model.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state + last_hidden_state = last_hidden_state[:, :-1, :] # (B, L-1, H) + if logits_to_keep is not None: + last_hidden_state = last_hidden_state[:, -logits_to_keep:, :] # (B, logits_to_keep, H) + return last_hidden_state + + # Get the per-token log probabilities for the completions for the model and the reference model + @profiling_decorator + def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, batch_size=None) -> torch.Tensor: + batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak + all_logps = [] + for i in range(0, input_ids.size(0), batch_size): + input_ids_batch = input_ids[i : i + batch_size] + attention_mask_batch = attention_mask[i : i + batch_size] + + # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded + logits = model( + input_ids=input_ids_batch, attention_mask=attention_mask_batch, logits_to_keep=logits_to_keep + 1 + ).logits + logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred + input_ids_batch = input_ids_batch[:, -logits_to_keep:] + # Divide logits by sampling temperature. + # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details + logits = logits / self.temperature + logps = selective_log_softmax(logits, input_ids_batch) # compute logprobs for the input tokens + all_logps.append(logps) + return torch.cat(all_logps, dim=0) + + def _sync_fsdp_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None): + """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM.""" + if visited is None: + visited = set() + + for child_name, child_module in module.named_children(): + child_prefix = f"{prefix}.{child_name}" if prefix else child_name + self._sync_fsdp_params_to_vllm( + child_module, prefix=child_prefix, visited=visited + ) # recurse into the child + + if isinstance(module, FSDP): + with FSDP.summon_full_params(module, recurse=False, writeback=False): + for param_name, param in module.named_parameters(): + full_name = f"{prefix}.{param_name}" if prefix else param_name + for extra in ("_fsdp_wrapped_module.", "_checkpoint_wrapped_module."): + full_name = full_name.replace(extra, "") + + if full_name in visited: + continue # skip FSDP subtrees already traversed + visited.add(full_name) + + if self.vllm_mode == "server" and self.accelerator.is_main_process: + self.vllm_client.update_named_param(full_name, param.data) + elif self.vllm_mode == "colocate": + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(full_name, param.data)]) + + @profiling_decorator + def _move_model_to_vllm(self): + # For DeepSpeed ZeRO-3 and FSDP, we need to gather all parameters before operations + deepspeed_plugin = self.accelerator.state.deepspeed_plugin + zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 + if zero_stage_3: + import deepspeed + + gather_if_zero3 = deepspeed.zero.GatheredParameters + else: + gather_if_zero3 = nullcontext + + if is_peft_model(self.model): + # With PEFT and FSDP/DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as + # merging adapters in a sharded manner is not supported. + # TODO: does this work with FSDP? + with gather_if_zero3(list(self.model.parameters())): + self.model.merge_adapter() + + # Update vLLM weights while parameters are gathered + if self.is_fsdp_enabled: # note if using FSDP, gather_if_zero3 is nullcontext + # Update vLLM weights while parameters are gathered + # For PEFT with FSDP we need to use the memory efficient post-order traversal + self._sync_fsdp_params_to_vllm(self.model) + else: + # DeepSpeed ZeRO-3 with PEFT + for name, param in self.model.named_parameters(): + # When using PEFT, we need to recover the original parameter name and discard some parameters + name = name.removeprefix("base_model.model.").replace(".base_layer", "") + if self.model.prefix in name: + continue + # When module to save, remove its prefix and discard the original module + if "original_module" in name: + continue + name = name.replace("modules_to_save.default.", "") + + if self.vllm_mode == "server" and self.accelerator.is_main_process: + self.vllm_client.update_named_param(name, param.data) + elif self.vllm_mode == "colocate": + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(name, param.data)]) + # Unmerge adapters while parameters are still gathered + self.model.unmerge_adapter() + # Parameters will automatically be repartitioned when exiting the context + else: + # For non-PEFT models, simply gather (if needed) and update each parameter individually. + if self.is_fsdp_enabled: + self._sync_fsdp_params_to_vllm(self.model) # use memory-efficient post-order traversal for FSDP + else: + for name, param in self.model.named_parameters(): + with gather_if_zero3([param]): + if self.vllm_mode == "server" and self.accelerator.is_main_process: + self.vllm_client.update_named_param(name, param.data) + elif self.vllm_mode == "colocate": + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights([(name, param.data)]) + + # Reset cache on vLLM + if self.vllm_mode == "server" and self.accelerator.is_main_process: + self.vllm_client.reset_prefix_cache() + elif self.vllm_mode == "colocate": + self.llm.reset_prefix_cache() + + @profiling_decorator + def _prepare_inputs( + self, generation_batch: dict[str, Union[torch.Tensor, Any]] + ) -> dict[str, Union[torch.Tensor, Any]]: + # Prepares inputs for model training/evaluation by managing completion generation and batch handling. + # During training: + # - Receives the local generation batch (Per-GPU batch size × steps per generation) + # from the modified training dataloader instead of the standard local batch + # - Generates completions once for the entire generation batch and splits it into batches of size + # `per_device_train_batch_size` + # - Buffers these completions and returns the appropriate slice for the current accumulation step + # - Optimizes by regenerating completions only periodically (every steps_per_generation * num_iterations) + # During evaluation: + # - The input is treated as a standard local batch (no accumulation, no multiple iterations) + # - Completions are generated for each batch without buffering or reuse + # Returns a single local batch in both cases. + + mode = "train" if self.model.training else "eval" + if mode == "train": + generate_every = self.args.steps_per_generation * self.num_iterations + if self._step % generate_every == 0 or self._buffered_inputs is None: + # self._buffered_inputs=None can occur when resuming from a checkpoint + generation_batch = self._generate_and_score_completions(generation_batch) + generation_batch = shuffle_tensor_dict(generation_batch) + self._buffered_inputs = split_tensor_dict(generation_batch, self.args.steps_per_generation) + inputs = self._buffered_inputs[self._step % self.args.steps_per_generation] + self._step += 1 + else: + # In evaluation, there is neither batch grouping for generation, nor multiple iterations, hence + # local generation batch == local eval batch + inputs = self._generate_and_score_completions(generation_batch) + return inputs + + def _generate_and_score_completions( + self, inputs: list[dict[str, Union[torch.Tensor, Any]]] + ) -> dict[str, Union[torch.Tensor, Any]]: + device = self.accelerator.device + mode = "train" if self.model.training else "eval" + + prompts = [x["prompt"] for x in inputs] + prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] + prompt_inputs = self.processing_class( + text=prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False + ) + prompt_inputs = super()._prepare_inputs(prompt_inputs) + prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] + + if self.max_prompt_length is not None: + prompt_ids = prompt_ids[:, -self.max_prompt_length :] + prompt_mask = prompt_mask[:, -self.max_prompt_length :] + + # Generate completions using either vLLM or regular generation + if self.use_vllm: + # First, update the vLLM weights if needed + if self.state.global_step != self._last_loaded_step: + self._move_model_to_vllm() + self._last_loaded_step = self.state.global_step + + # Generate completions using vLLM: gather all prompts and use them in a single call in the main process + if self.vllm_mode == "server": + all_prompts_text = gather_object(prompts_text) + if self.accelerator.is_main_process: + # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate + # num_generations outputs for each one. This is faster than generating outputs for each duplicate + # prompt individually. + ordered_set_of_prompts = all_prompts_text[:: self.num_generations] + with profiling_context(self, "vLLM.generate"): + completion_ids = self.vllm_client.generate( + prompts=ordered_set_of_prompts, + n=self.num_generations, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=-1 if self.top_k is None else self.top_k, + min_p=0.0 if self.min_p is None else self.min_p, + max_tokens=self.max_completion_length, + guided_decoding_regex=self.guided_decoding_regex, + ) + else: + completion_ids = [None] * len(all_prompts_text) + # Broadcast the completions from the main process to all processes, ensuring each process receives its + # corresponding slice. + completion_ids = broadcast_object_list(completion_ids, from_process=0) + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), + ) + completion_ids = completion_ids[process_slice] + + # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts + elif self.vllm_mode == "colocate": + if self.guided_decoding_regex: + guided_decoding = GuidedDecodingParams(backend="outlines", regex=self.guided_decoding_regex) + else: + guided_decoding = None + sampling_params = SamplingParams( + n=1, # vLLM on each GPU generates only 1 in colocate mode + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=-1 if self.top_k is None else self.top_k, + min_p=0.0 if self.min_p is None else self.min_p, + max_tokens=self.max_completion_length, + guided_decoding=guided_decoding, + ) + + if self.vllm_tensor_parallel_size > 1: + # Gather prompts from all ranks in the TP group and flatten. + # Each rank starts with its own prompts; after gathering, all ranks see the full group set. + orig_size = len(prompts_text) + gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)] + torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group) + all_prompts_text = [p for sublist in gathered_prompts for p in sublist] + else: + all_prompts_text = prompts_text + + with profiling_context(self, "vLLM.generate"): + all_outputs = self.llm.generate(all_prompts_text, sampling_params=sampling_params, use_tqdm=False) + + completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] + + if self.vllm_tensor_parallel_size > 1: + # Slice completions for this rank within its TP group. + # Each rank generates all outputs — we keep only our share. + local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) + tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) + completion_ids = completion_ids[tp_slice] + + # Pad the completions, and concatenate them with the prompts + completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id) + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) + else: + # Regular generation path + with unwrap_model_for_generation( + self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model: + with ( + FSDP.summon_full_params(self.model_wrapped, recurse=False) + if self.is_fsdp_enabled + else nullcontext() + ): + prompt_completion_ids = unwrapped_model.generate( + prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config + ) + + # Compute prompt length and extract completion ids + prompt_length = prompt_ids.size(1) + prompt_ids = prompt_completion_ids[:, :prompt_length] + completion_ids = prompt_completion_ids[:, prompt_length:] + + # Mask everything after the first EOS token + is_eos = completion_ids == self.processing_class.eos_token_id + eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) + eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] + sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) + completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + + # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need + # to re-tokenize completions if the reward is computed from tokens. + completion_ids_list = [ + [id.item() for id, m in zip(row, mask_row) if m] for row, mask_row in zip(completion_ids, completion_mask) + ] + + # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging + completion_lengths = completion_mask.sum(1) + + # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask + if self.mask_truncated_completions: + truncated_completions = ~is_eos.any(dim=1) + completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int() + + # Concatenate prompt_mask with completion_mask for logit computation + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) + + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size + + with torch.no_grad(): + # When using num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps + # old_per_token_logps == per_token_logps, so we can skip it's computation here, and use + # per_token_logps.detach() instead. + if self.num_iterations > 1 or self.args.steps_per_generation > self.args.gradient_accumulation_steps: + old_per_token_logps = self._get_per_token_logps( + self.model, prompt_completion_ids, attention_mask, logits_to_keep, batch_size + ) + else: + old_per_token_logps = None + + # Decode the generated completions + completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + if is_conversational(inputs[0]): + completions = [] + for prompt, completion in zip(prompts, completions_text): + bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else "" + completions.append([{"role": "assistant", "content": bootstrap + completion}]) + else: + completions = completions_text + + rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) + + # Repeat all input columns (but "prompt", "completion", and "completion_ids") to match the num of generations + keys = [key for key in inputs[0] if key not in ["prompt", "completion", "completion_ids"]] + reward_kwargs = {key: [example[key] for example in inputs] for key in keys} + + for i, (reward_func, reward_processing_class, reward_func_name) in enumerate( + zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names) + ): + with profiling_context(self, reward_func_name): + if isinstance(reward_func, nn.Module): # Module (no PretrainedModel) for compat with compiled models + if is_conversational(inputs[0]): + messages = [{"messages": p + c} for p, c in zip(prompts, completions)] + texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] + else: + texts = [p + c for p, c in zip(prompts, completions)] + reward_inputs = reward_processing_class( + text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False + ) + reward_inputs = super()._prepare_inputs(reward_inputs) + with torch.inference_mode(): + rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,) + else: + output_reward_func = reward_func( + prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs + ) + # Convert None values to NaN + output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] + + rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + + # If all reward functions return None for a given row, issue a detailed warning + if torch.isnan(rewards_per_func).all(dim=1).any(): + nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0] + row_reward_kwargs = {key: value[nan_row_idx] for key, value in reward_kwargs.items()} + row_reward_kwargs["prompt"] = prompts[nan_row_idx] + row_reward_kwargs["completion"] = completions[nan_row_idx] + warnings.warn( + f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. " + "Please ensure that at least one reward function returns a valid reward." + ) + + # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the + # completions may be distributed across processes + rewards_per_func = gather(rewards_per_func) + + # Apply weights to each reward function's output and sum + rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + + # Compute grouped-wise rewards + mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1) + std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1) + is_std_zero = torch.isclose(std_grouped_rewards, torch.zeros_like(std_grouped_rewards)) + + # Normalize the rewards to compute the advantages + mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0) + std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0) + advantages = rewards - mean_grouped_rewards + if self.scale_rewards: + advantages = advantages / (std_grouped_rewards + 1e-4) + + # Slice to keep only the local part of the data + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), + ) + all_process_advantages = advantages.clone() # keep the aggregated advantages for logging + advantages = advantages[process_slice] + + # Log the metrics + if mode == "train": + self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item() + self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] + + # Log completion lengths, mean, min, max + agg_completion_lengths = self.accelerator.gather(completion_lengths) + self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) + + # Identify sequences that terminated with EOS and log their lengths + agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1)) + term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos] + clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths) + self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio) + if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found + term_completion_lengths = torch.zeros(1, device=device) + self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) + + # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) + for i, reward_func_name in enumerate(self.reward_func_names): + mean_rewards = torch.nanmean(rewards_per_func[:, i]).item() + self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards) + std_rewards = nanstd(rewards_per_func[:, i]).item() + self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_rewards) + self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item()) + self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item()) + self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item()) + + # Log prompt and completion texts + self._textual_logs["prompt"].extend(gather_object(prompts_text)) + self._textual_logs["completion"].extend(gather_object(completions_text)) + for i, name in enumerate(self.reward_func_names): + self._textual_logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) + self._textual_logs["advantages"].extend(all_process_advantages.tolist()) + + return { + "prompt_ids": prompt_ids, + "prompt_mask": prompt_mask, + "completion_ids": completion_ids, + "completion_mask": completion_mask, + "advantages": advantages, + "old_per_token_logps": old_per_token_logps, + } + + def compute_liger_loss(self, unwrapped_model, inputs): + # Compute the per-token log probabilities for the model + prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] + completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] + input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + + # Compute the KL divergence between the model and the reference model + ref_per_token_logps = None + if self.beta != 0.0: + with torch.no_grad(): + if self.ref_model is not None: + ref_per_token_logps = self._get_per_token_logps( + self.ref_model, input_ids, attention_mask, logits_to_keep + ) + else: + with self.accelerator.unwrap_model(self.model).disable_adapter(): + ref_per_token_logps = self._get_per_token_logps( + self.model, input_ids, attention_mask, logits_to_keep + ) + + # get the last hidden state of the model + last_hidden_state = self._get_last_hidden_state(unwrapped_model, input_ids, attention_mask, logits_to_keep) + + # compute loss and metrics using liger grpo loss + loss, metrics = self.liger_grpo_loss( + _input=last_hidden_state, + lin_weight=unwrapped_model.lm_head.weight, + selected_token_ids=completion_ids, + attention_mask=completion_mask, + advantages=inputs["advantages"], + bias=unwrapped_model.lm_head.bias, + old_per_token_logps=inputs["old_per_token_logps"], + ref_per_token_logps=ref_per_token_logps, + ) + # Extract metrics from the liger_grpo_loss output + # KL divergence is the first metric when beta is non-zero + mean_kl = metrics[0] if self.beta != 0.0 else None + clip_ratio = metrics[-1] + + mode = "train" if self.model.training else "eval" + if self.beta != 0.0: + self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).mean().item()) + self._metrics[mode]["clip_ratio"].append(self.accelerator.gather(clip_ratio).mean().item()) + return loss + + @profiling_decorator + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + if return_outputs: + raise ValueError("The GRPOTrainer does not support returning outputs") + if self.use_liger_loss: + # Compute the loss using the liger grpo loss + unwrapped_model = self.accelerator.unwrap_model(model) + return self._forward_redirection(model, unwrapped_model, self.compute_liger_loss, unwrapped_model, inputs) + else: + return self._compute_loss(model, inputs) + + def _compute_loss(self, model, inputs): + # Compute the per-token log probabilities for the model + prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] + completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] + input_ids = torch.cat([prompt_ids, completion_ids], dim=1) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + + per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) + + # Compute the KL divergence between the model and the reference model + if self.beta != 0.0: + with torch.no_grad(): + if self.ref_model is not None: + ref_per_token_logps = self._get_per_token_logps( + self.ref_model, input_ids, attention_mask, logits_to_keep + ) + else: + with self.accelerator.unwrap_model(self.model).disable_adapter(): + ref_per_token_logps = self._get_per_token_logps( + self.model, input_ids, attention_mask, logits_to_keep + ) + per_token_kl = ( + torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 + ) + + # Compute the loss + advantages = inputs["advantages"] + # When using num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps + # old_per_token_logps == per_token_logps, so we can skip it's computation + # (see _generate_and_score_completions) and use per_token_logps.detach() instead. + old_per_token_logps = ( + per_token_logps.detach() if inputs["old_per_token_logps"] is None else inputs["old_per_token_logps"] + ) + coef_1 = torch.exp(per_token_logps - old_per_token_logps) + coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) + + # Two-sided clipping + if self.args.delta is not None: + coef_1 = torch.clamp(coef_1, max=self.args.delta) + + per_token_loss1 = coef_1 * advantages.unsqueeze(1) + per_token_loss2 = coef_2 * advantages.unsqueeze(1) + per_token_loss = -torch.min(per_token_loss1, per_token_loss2) + if self.beta != 0.0: + per_token_loss = per_token_loss + self.beta * per_token_kl + + if self.loss_type == "grpo": + loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean() + elif self.loss_type == "bnpo": + loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) + elif self.loss_type == "dr_grpo": + loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) + else: + raise ValueError(f"Unknown loss type: {self.loss_type}") + + # Log the metrics + mode = "train" if self.model.training else "eval" + + if self.beta != 0.0: + mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum() + self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item()) + + # Compute the clipped probability ratios + is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0) + is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0) + is_region_clipped = is_low_clipped | is_high_clipped + + low_clip = (is_low_clipped * completion_mask).sum() / completion_mask.sum() + high_clip = (is_high_clipped * completion_mask).sum() / completion_mask.sum() + clip_ratio = (is_region_clipped * completion_mask).sum() / completion_mask.sum() + + gathered_low_clip = self.accelerator.gather(low_clip) + self._metrics[mode]["clip_ratio/low_mean"].append(gathered_low_clip.nanmean().item()) + self._metrics[mode]["clip_ratio/low_min"].append(nanmin(gathered_low_clip).item()) + gathered_high_clip = self.accelerator.gather(high_clip) + self._metrics[mode]["clip_ratio/high_mean"].append(gathered_high_clip.nanmean().item()) + self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item()) + gathered_clip_ratio = self.accelerator.gather(clip_ratio) + self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item()) + return loss + + def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None): + inputs = self._prepare_inputs(inputs) + with torch.no_grad(): + with self.compute_loss_context_manager(): + loss = self.compute_loss(model, inputs) + loss = loss.mean().detach() + return loss, None, None + + def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: + mode = "train" if self.model.training else "eval" + metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics + + # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` + # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. + if mode == "eval": + metrics = {f"eval_{key}": val for key, val in metrics.items()} + + logs = {**logs, **metrics} + super().log(logs, start_time) + self._metrics[mode].clear() + + if self.accelerator.is_main_process and self.log_completions: + if is_rich_available(): + print_prompt_completions_sample( + self._textual_logs["prompt"], + self._textual_logs["completion"], + self._textual_logs["rewards"], + self._textual_logs["advantages"], + self.state.global_step, + self.num_completions_to_print, + ) + + if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None: + import pandas as pd + + table = { + "step": [str(self.state.global_step)] * len(self._textual_logs["prompt"]), + "prompt": self._textual_logs["prompt"], + "completion": self._textual_logs["completion"], + **self._textual_logs["rewards"], + "advantage": self._textual_logs["advantages"], + } + df = pd.DataFrame(table) + if self.wandb_log_unique_prompts: + df = df.drop_duplicates(subset=["prompt"]) + wandb.log({"completions": wandb.Table(dataframe=df)}) + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, list[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str` or `None`, *optional*, defaults to `None`): + Name of the model. + dataset_name (`str` or `None`, *optional*, defaults to `None`): + Name of the dataset used for training. + tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. + """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + tags = tags or set() + if isinstance(tags, str): + tags = {tags} + + if hasattr(self.model.config, "unsloth_version"): + tags.add("unsloth") + + tags.update(self._tag_names) + + citation = textwrap.dedent( + """\ + @article{zhihong2024deepseekmath, + title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}}, + author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo}, + year = 2024, + eprint = {arXiv:2402.03300}, + } + """ + ) + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), + trainer_name="GRPO", + trainer_citation=citation, + paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models", + paper_id="2402.03300", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/iterative_sft_config.py b/trl/trainer/iterative_sft_config.py new file mode 100644 index 0000000000000000000000000000000000000000..33eb72806a56aa62e1b2975f4a5d734ca42c2a79 --- /dev/null +++ b/trl/trainer/iterative_sft_config.py @@ -0,0 +1,102 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any, Optional + +from transformers import TrainingArguments + + +@dataclass +class IterativeSFTConfig(TrainingArguments): + r""" + Configuration class for the [`IterativeSFTTrainer`]. + + This class includes only the parameters that are specific to Iterative SFT training. For a full list of training + arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this + class may differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + > Parameters that control the model + + model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): + Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` + argument of the [`IterativeSFTTrainer`] is provided as a string. + + > Parameters that control the data preprocessing + + max_length (`int` or `None`, *optional*, defaults to `None`): + Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated. + truncation_mode (`str`, *optional*, defaults to `"keep_end"`): + The truncation mode to use, either `"keep_end"` or `"keep_start"`. + optimize_device_cache (`bool`, *optional*, defaults to `False`): + Whether to optimize accelerator cache for slightly more memory-efficient training. + """ + + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + + # Parameters whose default values are overridden from TrainingArguments + logging_steps: float = field( + default=10, + metadata={ + "help": ( + "Log every X updates steps. Should be an integer or a float in range `[0,1)`. " + "If smaller than 1, will be interpreted as ratio of total training steps." + ) + }, + ) + bf16: bool = field( + default=True, + metadata={ + "help": ( + "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change." + ) + }, + ) + + # Parameters that control the model + model_init_kwargs: Optional[dict[str, Any]] = field( + default=None, + metadata={ + "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of " + "the `IterativeSFTTrainer` is provided as a string." + }, + ) + + # Parameters that control the data preprocessing + max_length: Optional[int] = field( + default=None, + metadata={ + "help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated." + }, + ) + truncation_mode: str = field( + default="keep_end", + metadata={"help": "The truncation mode to use, either 'keep_end' or 'keep_start'."}, + ) + optimize_device_cache: bool = field( + default=False, + metadata={"help": "Whether to optimize accelerator cache for slightly more memory-efficient training."}, + ) + + def __post_init__(self): + super().__post_init__() + + if self.truncation_mode not in ["keep_end", "keep_start"]: + raise ValueError(f"truncation_mode must be either 'keep_end' or 'keep_start', got {self.truncation_mode}") diff --git a/trl/trainer/iterative_sft_trainer.py b/trl/trainer/iterative_sft_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..c1c82a22c26d0c9327ea2ea3d53216837fbd01e0 --- /dev/null +++ b/trl/trainer/iterative_sft_trainer.py @@ -0,0 +1,510 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import warnings +from pathlib import Path +from typing import Callable, Optional, Union + +import torch +from datasets import Dataset +from torch.utils.data import DataLoader +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + BaseImageProcessor, + DataCollator, + DataCollatorForLanguageModeling, + DataCollatorForSeq2Seq, + FeatureExtractionMixin, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + Trainer, + TrainingArguments, + is_wandb_available, +) +from transformers.trainer_utils import EvalLoopOutput +from transformers.utils import is_peft_available + +from ..core import PPODecorators +from .iterative_sft_config import IterativeSFTConfig +from .utils import generate_model_card, get_comet_experiment_url + + +if is_peft_available(): + from peft import PeftModel + + +if is_wandb_available(): + import wandb + + +class IterativeSFTTrainer(Trainer): + """ + The IterativeSFTTrainer can be used to finetune models with methods that requires some steps between optimization. + + Args: + model (`Union[str, PreTrainedModel]`): + Model to be trained. Can be either: + + - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or + a path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is + loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments + in `args.model_init_kwargs`. + - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. + args ([`IterativeSFTConfig`], *optional*, defaults to `None`): + Configuration for this trainer. If `None`, a default configuration is used. + data_collator (`DataCollator`, *optional*): + Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. + Will default to [`~transformers.default_data_collator`] if no `processing_class` is provided, an instance + of [`~transformers.DataCollatorWithPadding`] otherwise if the processing_class is a feature extractor or + tokenizer. + eval_dataset (`datasets.Dataset`): + The dataset to use for evaluation. + processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`): + Processing class used to process the data. If `None`, the processing class is loaded from the model's name + with [`~transformers.AutoTokenizer.from_pretrained`]. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to metric values. + max_length (`int`, *optional*, deprecated): + Maximum length of the tokenized sequence. Use `args.max_length` instead. + truncation_mode (`str`, *optional*, deprecated): + The truncation mode to use. Use `args.truncation_mode` instead. + optimize_device_cache (`bool`, *optional*, deprecated): + Whether to optimize accelerator cache. Use `args.optimize_device_cache` instead. + """ + + _tag_names = ["trl", "iterative-sft"] + + def __init__( + self, + model: Union[str, PreTrainedModel], + args: Optional[Union[IterativeSFTConfig, TrainingArguments]] = None, + data_collator: Optional[DataCollator] = None, + eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, + processing_class: Optional[ + Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] + ] = None, + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( + None, + None, + ), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None, + # Deprecated parameters + max_length: Optional[int] = None, + truncation_mode: Optional[str] = None, + optimize_device_cache: Optional[bool] = None, + ): + # Handle deprecated parameters + deprecated_params = {} + if max_length is not None: + deprecated_params["max_length"] = max_length + warnings.warn( + "The `max_length` parameter is deprecated and will be removed in version 0.20. " + "Pass it through the `args` parameter using `IterativeSFTConfig(max_length=...)` instead.", + DeprecationWarning, + ) + if truncation_mode is not None: + deprecated_params["truncation_mode"] = truncation_mode + warnings.warn( + "The `truncation_mode` parameter is deprecated and will be removed in version 0.20. " + "Pass it through the `args` parameter using `IterativeSFTConfig(truncation_mode=...)` instead.", + DeprecationWarning, + ) + if optimize_device_cache is not None: + deprecated_params["optimize_device_cache"] = optimize_device_cache + warnings.warn( + "The `optimize_device_cache` parameter is deprecated and will be removed in version 0.20 " + "Pass it through the `args` parameter using `IterativeSFTConfig(optimize_device_cache=...)` instead.", + DeprecationWarning, + ) + + # Args + model_id = model if isinstance(model, str) else model.config._name_or_path + if args is None: + model_name = model_id.split("/")[-1] + args = IterativeSFTConfig(f"{model_name}-IterativeSFT") + elif isinstance(args, TrainingArguments) and not isinstance(args, IterativeSFTConfig): + dict_args = args.to_dict() + dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token + dict_args.pop("push_to_hub_token") + args = IterativeSFTConfig(**dict_args) + + # Update args with deprecated parameters if provided + if deprecated_params: + for key, value in deprecated_params.items(): + setattr(args, key, value) + + # Handle the tokenizer + if processing_class is None: + processing_class = AutoTokenizer.from_pretrained(model_id) + + # Model + if args.model_init_kwargs is not None and not isinstance(model, str): + warnings.warn( + "You passed model_init_kwargs to the `IterativeSFTConfig`, but your model is already instantiated. " + "The `model_init_kwargs` will be ignored." + ) + if isinstance(model, str): + model = self._create_model_from_path(model, args) + + # PEFT configuration and model wrapping + if is_peft_available() and isinstance(model, PeftModel): + self.is_peft_model = True + else: + self.is_peft_model = False + + self.processing_class = processing_class + self.is_encoder_decoder = getattr(model.config, "is_encoder_decoder", False) + + if data_collator is None: + if self.is_encoder_decoder: + self.data_collator = DataCollatorForSeq2Seq( + processing_class, label_pad_token_id=-100, pad_to_multiple_of=8 + ) + else: + self.data_collator = DataCollatorForLanguageModeling(self.processing_class, mlm=False) + else: + self.data_collator = data_collator + + self.max_length = args.max_length + self.truncation_mode = args.truncation_mode + self.optimize_device_cache = args.optimize_device_cache + + super().__init__( + model=model, + args=args, + data_collator=self.data_collator, + eval_dataset=eval_dataset, + processing_class=processing_class, + compute_metrics=compute_metrics, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + self.create_optimizer_and_scheduler(self.args.max_steps) + + # prepare model, optimizer and lr_scheduler + self.model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( + self.model, self.optimizer, self.lr_scheduler + ) + + self.processing_class.truncation_side = "left" if self.truncation_mode == "keep_end" else "right" + + if not hasattr(self, "accelerator"): + raise AttributeError( + "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." + ) + + PPODecorators.optimize_device_cache = self.optimize_device_cache + + def _create_model_from_path(self, model_path: str, args: IterativeSFTConfig) -> PreTrainedModel: + """Creates a model from a path or model identifier.""" + model_init_kwargs = args.model_init_kwargs or {} + return AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs) + + def prepare_model_inputs(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor): + if attention_mask is None: + attention_mask = [torch.ones_like(ids) for ids in input_ids] + + if self.is_encoder_decoder: + input_data = self.data_collator( + [ + {"input_ids": ids, "attention_mask": att, "labels": lab} + for ids, att, lab in zip(input_ids, attention_mask, labels) + ] + ).to(self.model.device) + + input_data.pop("decoder_input_ids", None) # This is directly computed inside the model + + input_data["labels"][input_data["labels"] == self.processing_class.pad_token_id] = -100 + + else: + input_data = self.data_collator( + [{"input_ids": ids, "attention_mask": att} for ids, att in zip(input_ids, attention_mask)] + ).to(self.model.device) + + # truncate in case the user has provided input_ids, attention_mask and labels + if self.max_length is not None: + if self.truncation_mode == "keep_start": + input_data = {k: v[: self.max_length] for k, v in input_data.items()} + elif self.truncation_mode == "keep_end": + input_data = {k: v[-self.max_length :] for k, v in input_data.items()} + else: + raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") + + return input_data + + @staticmethod + def _step_safety_checker( + input_ids: list[torch.LongTensor], + attention_mask: list[torch.LongTensor], + labels: list[torch.LongTensor], + texts: list[str], + texts_labels: list[str], + ): + """ + Check if the input data is valid for training. + + Args: + input_ids (list[`torch.LongTensor`]): + List of tensors containing the input_ids + attention_mask (list[`torch.LongTensor`]): + List of tensors containing the attention_mask + labels (list[`torch.FloatTensor`]): + List of tensors containing the labels + texts (list[`str`]): + List of string containing the text input. + texts_labels (list[`str`]): + List of string containing the text labels. + + Returns: + `tuple`: The input data. + """ + if texts is None: + if attention_mask is None: + for name, tensor_list in zip(["input_ids", "labels"], [input_ids, labels]): + if not isinstance(tensor_list, list): + raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}") + if not isinstance(tensor_list[0], torch.Tensor): + raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}") + else: + for name, tensor_list in zip( + ["input_ids", "attention_mask", "labels"], [input_ids, attention_mask, labels] + ): + if not isinstance(tensor_list, list): + raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}") + if not isinstance(tensor_list[0], torch.Tensor): + raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}") + else: + if not isinstance(texts, list): + raise ValueError(f"'text' must be a list of strings - got {type(texts)}") + if not isinstance(texts[0], str): + raise ValueError(f"Elements in 'text' must be strings - got {type(texts[0])}") + if texts_labels is not None: + if not isinstance(texts_labels, list): + raise ValueError(f"'text_labels' must be a list of strings - got {type(texts_labels)}") + if not isinstance(texts_labels[0], str): + raise ValueError(f"Elements in 'text_labels' must be strings - got {type(texts_labels[0])}") + + return input_ids, attention_mask, labels, texts, texts_labels + + @PPODecorators.empty_device_cache() + def step( + self, + input_ids: Optional[list[torch.LongTensor]] = None, + attention_mask: Optional[list[torch.LongTensor]] = None, + labels: Optional[list[torch.LongTensor]] = None, + texts: Optional[list[str]] = None, + texts_labels: Optional[list[str]] = None, + ): + """ + Run an optimisation step given a list of input_ids, attention_mask, and labels or a list of text and text_labels. + Args: + input_ids (list[`torch.LongTensor`]): + List of tensors containing the input_ids (if not provided, text will be used) + attention_mask (list[`torch.LongTensor`], , *optional*): + List of tensors containing the attention_mask + labels (list[`torch.FloatTensor`], *optional*): + List of tensors containing the labels (if set to None, will default to input_ids) + texts (list[`str`], *optional*): + List of strings containing the text input (if not provided, input_ids will directly be used) + texts_labels (list[`str`], *optional*): + List of strings containing the text labels (if set to None, will default to text) + + Returns: + `dict[str, Any]`: A summary of the training statistics + """ + self.model.train() + + if self.state.global_step == 0: + self.tr_loss = torch.tensor(0.0).to(self.args.device) + self._globalstep_last_logged = self.state.global_step + + if input_ids is None and texts is None: + raise ValueError("Step should include `input_ids` or `texts` as keyword arguments.") + elif input_ids is not None and texts is not None: + warnings.warn( + "Both `input_ids` and `texts` argument are provided. `input_ids` will be ignored. " + "Please provide only one of the two.", + UserWarning, + ) + + if labels is None and texts_labels is None and self.is_encoder_decoder: + raise ValueError( + "No 'labels' or 'text_labels' are provided. When using an encoder-decoder architecture, 'labels' or 'text_labels' must be passed." + ) + + input_ids, attention_mask, labels, texts, texts_labels = self._step_safety_checker( + input_ids, attention_mask, labels, texts, texts_labels + ) + + if texts is not None: + model_inputs = self.processing_class( + texts, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt" + ) + + input_ids, attention_mask = model_inputs["input_ids"], model_inputs["attention_mask"] + + if texts_labels is not None: + labels = self.processing_class( + texts, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt" + )["input_ids"] + + if labels is None: + labels = input_ids + + model_inputs = self.prepare_model_inputs(input_ids, attention_mask, labels) + + model_inputs_names = list(model_inputs.keys()) + + batch_dict = {} + batch_dict.update(model_inputs) + + def collator(data): + return_dict = dict() + for key in data[0]: + if key in ["input_ids", "attention_mask", "labels"]: + return_dict[key] = torch.stack([d[key] for d in data]).to(self.model.device) + return return_dict + + batch_data = Dataset.from_dict(batch_dict) + batch_data.set_format("torch") + + step_dataloader = DataLoader( + batch_data, + batch_size=self.args.per_device_train_batch_size, + shuffle=True, + collate_fn=collator, + ) + + for _, batch in enumerate(step_dataloader): + with self.accelerator.accumulate(self.model): + model_inputs = {k: batch[k] for k in model_inputs_names} + loss = self.compute_loss(self.model, model_inputs) + + if self.args.n_gpu > 1: + loss = loss.mean() + + tr_loss_step = loss.detach() + + self.accelerator.backward(loss) + + if self.accelerator.sync_gradients and self.args.max_grad_norm is not None: + self.accelerator.clip_grad_norm_( + self.model.parameters(), + self.args.max_grad_norm, + ) + + self.optimizer.step() + self.optimizer.zero_grad() + if self.lr_scheduler is not None: + self.lr_scheduler.step() + + self.state.global_step += 1 + + # update stats etc + self.tr_loss += tr_loss_step + + self._maybe_log_save_evaluate() + + def _maybe_log_save_evaluate(self): + # check if eval is required + if self.args.eval_steps is not None: + if self.state.global_step % self.args.eval_steps == 0 and self.state.global_step != 0: + self.evaluate(self.eval_dataset) + + # check if logging is required + if self.args.logging_steps is not None: + if self.state.global_step % self.args.logging_steps == 0 and self.state.global_step != 0: + logs: dict[str, float] = {} + + tr_loss_scalar = self._nested_gather(self.tr_loss).mean().item() + + # reset tr_loss to zero + self.tr_loss -= self.tr_loss + + logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) + logs["learning_rate"] = self._get_learning_rate() + + self._globalstep_last_logged = self.state.global_step + + self.log(logs) + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, list[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str` or `None`, *optional*, defaults to `None`): + Name of the model. + dataset_name (`str` or `None`, *optional*, defaults to `None`): + Name of the dataset used for training. + tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. + """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + tags = tags or set() + if isinstance(tags, str): + tags = {tags} + + if hasattr(self.model.config, "unsloth_version"): + tags.add("unsloth") + + tags.update(self._tag_names) + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), + trainer_name="Iterative SFT", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/judges.py b/trl/trainer/judges.py new file mode 100644 index 0000000000000000000000000000000000000000..852fa06f68383872c1d8082bb33ca023a2bc555f --- /dev/null +++ b/trl/trainer/judges.py @@ -0,0 +1,457 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import concurrent.futures +import logging +from abc import ABC, abstractmethod +from typing import Optional, Union + +import numpy as np +from accelerate import Accelerator +from huggingface_hub import InferenceClient +from transformers.utils import is_openai_available + +from ..import_utils import is_llm_blender_available + + +if is_llm_blender_available(): + import llm_blender + +if is_openai_available(): + from openai import OpenAI + + +DEFAULT_PAIRWISE_SYSTEM_PROMPT = '''I require a leaderboard for various large language models. I'll provide you with prompts given to these models and their corresponding outputs. Your task is to assess these responses, and select the model that produces the best output from a human perspective. + +## Instruction + +{{ + "instruction": """{prompt}""", +}} + +## Model Outputs + +Here are the unordered outputs from the models. Each output is associated with a specific model, identified by a unique model identifier. + +{{ + {{ + "model_identifier": "0", + "output": """{response0}""" + }}, + {{ + "model_identifier": "1", + "output": """{response1}""" + }} +}} + +## Task + +Evaluate the models on the basis of the quality and relevance of their results, and select the model that generated the best result. Reply with the identifier of the best model. Our evaluation will only take into account the first character of your answer, so make sure it contains only one of the identifiers and nothing else (no quotation marks, no spaces, no new lines, ...). +''' + + +class BaseJudge(ABC): + """ + Base class for judges. The subclasses of this class should implement the `judge` method. + """ + + @abstractmethod + def judge(self, prompts: list[str], completions: list[str], shuffle_order: bool = True) -> list: + raise NotImplementedError("Judge subclasses must implement the `judge` method.") + + +class BaseRankJudge(ABC): + """ + Base class for LLM ranking judges. + + **Example**: + ```python + class MyRankJudge(BaseRankJudge): + def judge(self, prompts, completions, shuffle_order=True): + return ... # Your ranking logic here + + judge = MyRankJudge() + judge.judge( + prompts=["The capital of France is", "The capital of Germany is"], + completions=[[" Paris", " Marseille", "Lyon"], [" Munich", " Berlin"]] + ) # [[0, 1, 2], [1, 0]] + ``` + """ + + @abstractmethod + def judge(self, prompts: list[str], completions: list[list[str]], shuffle_order: bool = True) -> list[list[int]]: + """ + Judge the completion for the given prompts and return the ranks of each completion. + + Args: + prompts (`list[str]`): + List of prompts. + completions (`list[list[str]]`): + List of completions list, where each element is a list of completions for the corresponding prompt. + shuffle_order (`bool`, *optional*, defaults to `True`): + Whether to shuffle the order of the completions to avoid positional bias. + + Returns: + `list[list[int]]`: + List of lists of idxs, where each list contains the ranks of the completions for the corresponding + prompt. E.g., `[1, 2, 0]` means that the second completion (`idx=1`) is the best, followed by the + third, and then the first. + """ + raise NotImplementedError("Judge subclasses must implement the `judge` method.") + + +class BasePairwiseJudge(BaseJudge): + """ + Base class for pairwise judges. + """ + + @abstractmethod + def judge(self, prompts: list[str], completions: list[list[str]], shuffle_order: bool = True) -> list[int]: + """ + Judge the completion pairs for the given prompts. + + Args: + prompts (`list[str]`): + List of prompts. + completions (`list[list[str]]`): + List of completions pairs, where each element is a pair of completions for the corresponding prompt. + shuffle_order (`bool`, *optional*, defaults to `True`): + Whether to shuffle the order of the completions to avoid positional bias. + + Returns: + `list[int]`: + List of idxs, where each idx is the rank of the best completion for the corresponding prompt. + E.g., `1` means that the second completion (`idx=1`) is the best. + + Note: + If the judge returns `-1` for any prompt, it indicates that the inner process used to compute the + preference has failed. For instance, this could occur if the underlying language model returned an invalid + answer. In such cases, the caller should handle these invalid indices appropriately, possibly by + implementing fallback logic or error handling. + """ + raise NotImplementedError("Judge subclasses must implement the `judge` method.") + + +class BaseBinaryJudge(BaseJudge): + """ + Base class for binary judges. + """ + + @abstractmethod + def judge( + self, + prompts: list[str], + completions: list[str], + gold_completions: Optional[list[str]] = None, + shuffle_order: bool = True, + ) -> list[int]: + """ + Judge the completion for a given prompt. Used to assess if a completion satisfies a constraint. + + This base class should be used to implement binary evaluations as done in section 4.1.4 of the + [CGPO paper](https://huggingface.co/papers/2409.20370). + It is relevant for assessing whether a prompt completion pair satisfies a specific contraint. + + Args: + prompts (`list[str]`): List of prompts. + completions (`list[str]`): List of completions. + gold_completions (`list[str]`, `optional`): List of gold completions if it exists. + shuffle_order (`bool`): Whether to shuffle the order of the completions to avoid positional bias. + + Returns: + list[int]: A list of binary labels: + - 1 indicates that the completion satisfies the evaluated constraint. + - 0 indicates that the completion does not satisfy the evaluated constraint. + + Note: + If the judge returns -1 for any prompt, it indicates that the inner process used to compute the preference has failed. + For instance, this could occur if the underlying language model or rule based contraint returned an invalid answer. + In such cases, the caller should handle these invalid indices appropriately, possibly by implementing fallback logic or error handling. + """ + raise NotImplementedError("Judge subclasses must implement the `judge` method.") + + +class PairRMJudge(BasePairwiseJudge): + """ + LLM judge based on the PairRM model from AllenAI. + + This judge uses the PairRM model to rank pairs of completions for given prompts. It's designed for pairwise + comparison of language model outputs. The PairRM model is loaded using the llm-blender library and runs on the + default Accelerator device. + + **Attributes**: + + blender (`llm_blender.Blender`): + An instance of the Blender class from llm-blender. + + **Example**: + ```python + >>> pairrm_judge = PairRMJudge() + >>> prompts = ["Translate 'hello' to French", "What's the capital of Japan?"] + >>> completions = [["Bonjour", "Salut"], ["Kyoto", "Tokyo"]] + >>> results = pairrm_judge.judge(prompts, completions) + >>> print(results) # [0, 1] (indicating the first completion is preferred for the first prompt and the second) + ``` + + + + This class requires the llm-blender library to be installed. Install it with: `pip install llm-blender`. + + + """ + + def __init__(self): + if not is_llm_blender_available(): + raise ValueError("llm-blender is not installed. Please install it with `pip install llm-blender`.") + self.blender = llm_blender.Blender() + self.blender.loadranker("llm-blender/PairRM", device=Accelerator().device) + + def judge( + self, + prompts: list[str], + completions: list[list[str]], + shuffle_order: bool = True, + return_scores: bool = False, + temperature: float = 1.0, + ) -> list[Union[int, float]]: + """ + Judge the completion pairs for the given prompts using the PairRM model. + + Args: + prompts (`list[str]`): + List of prompts to judge. + completions (`list[list[str]]`): + List of completion pairs for each prompt. + shuffle_order (`bool`, *optional*, defaults to `True`): + Whether to shuffle the order of the completions to avoid positional bias. + return_scores (`bool`, *optional*, defaults to `False`): + If `True`, return probability scores of the first completion instead of ranks (i.e. a *soft-judge*). + temperature (`float`, *optional*, defaults to `1.0`): + Temperature for scaling logits if `return_scores` is True. + + Returns: + `Union[list[int, float]]`: + If `return_scores` is `False`, returns a list of ranks (`0` or `1`) for each prompt, indicating which + completion is preferred. + If `return_scores` is `True`, returns softmax probabilities for the first completion. + + Raises: + `ValueError`: + If the number of completions per prompt is not exactly 2. + + Note: + Unlike llm-blender, ranks are 0-indexed (`0` means the first completion is preferred). + """ + + if len(completions[0]) != 2: + raise ValueError("PairRM judge requires exactly 2 completions per prompt.") + + # Shuffle the order of the completions to avoid positional bias + if shuffle_order: + flip_mask = np.random.choice([True, False], size=len(prompts)) + completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions)] + + # Rank the completions + ranks = self.blender.rank(prompts, completions, return_scores=return_scores, disable_tqdm=True) + if not return_scores: + ranks -= 1 # PairRM rank is 1-indexed, so we subtract 1 to make it 0-indexed + else: + # scale the logits by temperature + ranks /= temperature + + # Flip back the ranks or scores to the original order if needed + if shuffle_order: + ranks[flip_mask] = ranks[flip_mask][:, ::-1] + + # Return the ranks or score probability + if return_scores: + logit_max = np.amax(ranks, axis=-1, keepdims=True) + exp_logit_shifted = np.exp(ranks - logit_max) + probs = exp_logit_shifted / np.sum(exp_logit_shifted, axis=-1, keepdims=True) + return probs[:, 0].tolist() + else: + return ranks[:, 0].tolist() + + +class HfPairwiseJudge(BasePairwiseJudge): + """ + Pairwise judge based on the Hugging Face API with chat completion. + + This judge is relevant for assessing the quality chat models, where the completion is a response to a given prompt. + + Args: + model (`str`, *optional*, defaults to `"meta-llama/Meta-Llama-3-70B-Instruct"`): + Model to use for the judge. + token (`str`, *optional*): + Hugging Face API token to use for the [`huggingface_hub.InferenceClient`]. + system_prompt (`str` or `None`, *optional*, defaults to `None`): + The system prompt to be used for the judge. If not provided, a default prompt is used. Note that the system + prompt should contain the following placeholders: `{prompt}`, `{response0}`, and `{response1}`. Also, the + inference is called with `max_tokens=1`, consequently the system prompt should ask for a single token + response. + """ + + def __init__( + self, + model="meta-llama/Meta-Llama-3-70B-Instruct", + token: Optional[str] = None, + system_prompt: Optional[str] = None, + ): + self.client = InferenceClient(model=model, token=token) + self.system_prompt = system_prompt or DEFAULT_PAIRWISE_SYSTEM_PROMPT + + def judge(self, prompts: list[str], completions: list[list[str]], shuffle_order: bool = True) -> list[int]: + # Shuffle the order of the completions to avoid positional bias + if shuffle_order: + flip_mask = np.random.choice([True, False], size=len(prompts)) + completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions)] + + # Define a function to get the rank for a single prompt, will be called concurrently + def get_rank(prompt, candidates): + content = self.system_prompt.format(prompt=prompt, response0=candidates[0], response1=candidates[1]) + completion = self.client.chat_completion(messages=[{"role": "user", "content": content}], max_tokens=1) + response = completion.choices[0].message.content + if response in ["0", "1"]: + return int(response) + else: + logging.debug(f"Invalid response from the judge model: '{response}'. Returning -1.") + return -1 + + # Call the completions concurrently + with concurrent.futures.ThreadPoolExecutor() as executor: + ranks = list(executor.map(get_rank, prompts, completions)) + + # Flip back the ranks to the original order if needed + if shuffle_order: + ranks = [ranks[i] if not flip else 1 - ranks[i] for i, flip in enumerate(flip_mask)] + + # Return the ranks + return ranks + + +class OpenAIPairwiseJudge(BasePairwiseJudge): + """ + Judge based on the OpenAI API. + + This judge is relevant for assessing the quality chat models, where the completion is a response to a given prompt. + + Args: + model (`str`, *optional*, defaults to `"gpt-4-turbo-preview"`): + Model to use for the judge. + system_prompt (`str` or `None`, *optional*, defaults to `None`): + System prompt to be used for the judge. If not provided, a default prompt is used. Note that the system + prompt should contain the following placeholders: `{prompt}`, `{response0}`, and `{response1}`. Also, the + inference is called with `max_tokens=1`, consequently the system prompt should ask for a single token + response. + max_requests (`int` or `None`, *optional*, defaults to `1000`): + Maximum number of requests to make to the OpenAI API. If set to `None`, there is no limit. + """ + + def __init__( + self, model="gpt-4-turbo-preview", system_prompt: Optional[str] = None, max_requests: Union[int, None] = 1_000 + ): + if not is_openai_available(): + raise ValueError("OpenAI client is not installed. Please install it with 'pip install openai'.") + self.client = OpenAI() + self.model = model + self.system_prompt = system_prompt or DEFAULT_PAIRWISE_SYSTEM_PROMPT + self.max_requests = max_requests + self.num_requests = 0 + self._warned = False + + def judge(self, prompts: list[str], completions: list[list[str]], shuffle_order: bool = True) -> list[int]: + # Check if the limit of requests is reached, if so, use random choice instead + if self.max_requests is not None and self.num_requests >= self.max_requests: + if not self._warned: # Print the warning only once + logging.warning( + f"Reached the maximum number of requests ({self.max_requests}). From now on, returning -1 instead. " + " To increase the limit, set `max_requests` to a higher value, or to `None` for no limit." + ) + self._warned = True + return [-1] * len(prompts) + + # Shuffle the order of the completions to avoid positional bias + if shuffle_order: + flip_mask = np.random.choice([True, False], size=len(prompts)) + completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions)] + + # Define a function to get the rank for a single prompt, will be called concurrently + def get_rank(prompt, candidates): + content = self.system_prompt.format(prompt=prompt, response0=candidates[0], response1=candidates[1]) + messages = [{"role": "user", "content": content}] + completion = self.client.chat.completions.create(model=self.model, messages=messages, max_tokens=1) + response = completion.choices[0].message.content + if response in ["0", "1"]: + return int(response) + else: + logging.debug(f"Invalid response from the judge model: '{response}'. Returning -1.") + return -1 + + # Call the completions concurrently + with concurrent.futures.ThreadPoolExecutor() as executor: + ranks = list(executor.map(get_rank, prompts, completions)) + + # Flip back the ranks to the original order if needed + if shuffle_order: + ranks = [ranks[i] if not flip else 1 - ranks[i] for i, flip in enumerate(flip_mask)] + + # Update the number of requests + self.num_requests += len(prompts) + + # Return the ranks + return ranks + + +class AllTrueJudge(BaseBinaryJudge): + """ + Unify the decision of multiple [`BaseBinaryJudge`] instances. + + Returns `1` only if all inner binary judges return `1`. If any judge returns `0`, it returns `0`. + If any judge returns `-1`, indicating a failure in its process, this judge will also return `-1`. + + Implements the Mixture of Judges as described in the [CGPO paper](https://huggingface.co/papers/2409.20370). + + Args: + judges (`list[BaseBinaryJudge]`): A list of [`BaseBinaryJudge`] instances whose decisions will be unified. + """ + + def __init__(self, judges: list[BaseBinaryJudge]): + self.judges = judges + + def judge( + self, + prompts: list[str], + completions: list[str], + gold_completions: Optional[list[str]] = None, + shuffle_order: bool = True, + ) -> list[int]: + all_binary_judgments = [ + judge.judge(prompts, completions, gold_completions, shuffle_order) for judge in self.judges + ] + output = [] + for binary_judgments in zip(*all_binary_judgments): + # Check that all values are in {0, 1, -1} + if any(binary_judgment not in {0, 1, -1} for binary_judgment in binary_judgments): + raise ValueError( + f"Invalid binary judgment: {binary_judgments}, expected list of values in {{0, 1, -1}}." + ) + + # Unify the decision + if -1 in binary_judgments: + output.append(-1) + elif all(binary_judgment == 1 for binary_judgment in binary_judgments): + output.append(1) + else: + output.append(0) + return output diff --git a/trl/trainer/kto_config.py b/trl/trainer/kto_config.py new file mode 100644 index 0000000000000000000000000000000000000000..6a5a533154eb31281f686b0baac071a31bf529a2 --- /dev/null +++ b/trl/trainer/kto_config.py @@ -0,0 +1,232 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any, Optional + +from transformers import TrainingArguments + + +@dataclass +class KTOConfig(TrainingArguments): + r""" + Configuration class for the [`KTOTrainer`]. + + This class includes only the parameters that are specific to KTO training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may + differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want + to use the default data collator. + max_prompt_length (`int` or `None`, *optional*, defaults to `512`): + Maximum length of the prompt. This argument is required if you want to use the default data collator. + max_completion_length (`int` or `None`, *optional*, defaults to `None`): + Maximum length of the completion. This argument is required if you want to use the default data collator + and your model is an encoder-decoder. + beta (`float`, *optional*, defaults to `0.1`): + Parameter controlling the deviation from the reference model. Higher β means less deviation from the + reference model. + loss_type (`str`, *optional*, defaults to `"kto"`): + Type of loss to use. Possible values are: + + - `"kto"`: KTO loss from the [KTO](https://huggingface.co/papers/2402.01306) paper. + - `"apo_zero_unpaired"`: Unpaired variant of APO-zero loss from the [APO](https://huggingface.co/papers/2408.06266) paper. + + desirable_weight (`float`, *optional*, defaults to `1.0`): + Desirable losses are weighed by this factor to counter unequal number of desirable and undesirable paris. + undesirable_weight (`float`, *optional*, defaults to `1.0`): + Undesirable losses are weighed by this factor to counter unequal number of desirable and undesirable pairs. + label_pad_token_id (`int`, *optional*, defaults to `-100`): + Label pad token id. This argument is required if you want to use the default data collator. + padding_value (`int` or `None`, *optional*, defaults to `None`): + Padding value to use. If `None`, the padding value of the tokenizer is used. + truncation_mode (`str`, *optional*, defaults to `"keep_end"`): + Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`. + This argument is required if you want to use the default data collator. + generate_during_eval (`bool`, *optional*, defaults to `False`): + If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during + evaluation. + is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`): + When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument, + you need to specify if the model returned by the callable is an encoder-decoder model. + precompute_ref_log_probs (`bool`, *optional*, defaults to `False`): + Whether to precompute reference model log probabilities for training and evaluation datasets. This is + useful when training without the reference model to reduce the total GPU memory needed. + model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): + Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a + string. + ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): + Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model + from a string. + dataset_num_proc: (`int` or `None`, *optional*, defaults to `None`): + Number of processes to use for processing the dataset. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model and reference model. + use_liger_loss (`bool`, *optional*, defaults to `False`): + Whether to use Liger loss. It requires liger-kernel to be installed. + base_model_attribute_name (`str`, *optional*, defaults to `"model"`): + Name of the attribute in the model that contains the base model. This is used to get the base model from + the model when the model does not have a `get_decoder` method in the case when `use_liger_loss` is `True`. + """ + + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs", "ref_model_init_kwargs"] + + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=1e-6, + metadata={"help": "The initial learning rate for AdamW."}, + ) + logging_steps: float = field( + default=10, + metadata={ + "help": ( + "Log every X updates steps. Should be an integer or a float in range `[0,1)`. " + "If smaller than 1, will be interpreted as ratio of total training steps." + ) + }, + ) + bf16: bool = field( + default=True, + metadata={ + "help": ( + "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change." + ) + }, + ) + + max_length: Optional[int] = field( + default=1024, + metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."}, + ) + max_prompt_length: Optional[int] = field( + default=512, + metadata={ + "help": "Maximum length of the prompt. This argument is required if you want to use the default data " + "collator and your model is an encoder-decoder." + }, + ) + max_completion_length: Optional[int] = field( + default=None, + metadata={ + "help": "Maximum length of the completion. This argument is required if you want to use the default data " + "collator and your model is an encoder-decoder." + }, + ) + beta: float = field( + default=0.1, + metadata={ + "help": "Parameter controlling the deviation from the reference model. Higher β means less deviation from " + "the reference model." + }, + ) + loss_type: str = field( + default="kto", + metadata={ + "help": "Type of loss to use.", + "choices": ["kto", "apo_zero_unpaired"], + }, + ) + desirable_weight: float = field( + default=1.0, + metadata={ + "help": "Desirable losses are weighed by this factor to counter unequal number of desirable and " + "undesirable pairs.", + }, + ) + undesirable_weight: float = field( + default=1.0, + metadata={ + "help": "Undesirable losses are weighed by this factor to counter unequal number of desirable and " + "undesirable pairs.", + }, + ) + label_pad_token_id: int = field( + default=-100, + metadata={ + "help": "Label pad token id. This argument is required if you want to use the default data collator." + }, + ) + padding_value: Optional[int] = field( + default=None, + metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."}, + ) + truncation_mode: str = field( + default="keep_end", + metadata={ + "help": "Truncation mode to use when the prompt is too long.", + "choices": ["keep_end", "keep_start"], + }, + ) + generate_during_eval: bool = field( + default=False, + metadata={ + "help": "If `True`, generates and logs completions from both the model and the reference model to W&B " + "during evaluation." + }, + ) + is_encoder_decoder: Optional[bool] = field( + default=None, + metadata={ + "help": "When using the `model_init` argument (callable) to instantiate the model instead of the `model` " + "argument, you need to specify if the model returned by the callable is an encoder-decoder model." + }, + ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropout in the model."}, + ) + precompute_ref_log_probs: bool = field( + default=False, + metadata={ + "help": "Whether to precompute reference model log probabilities for training and evaluation datasets. " + "This is useful when training without the reference model to reduce the total GPU memory needed." + }, + ) + model_init_kwargs: Optional[dict[str, Any]] = field( + default=None, + metadata={ + "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model " + "from a string." + }, + ) + ref_model_init_kwargs: Optional[dict[str, Any]] = field( + default=None, + metadata={ + "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the " + "reference model from a string." + }, + ) + dataset_num_proc: Optional[int] = field( + default=None, + metadata={"help": "Number of processes to use for processing the dataset."}, + ) + use_liger_loss: bool = field( + default=False, + metadata={"help": "Whether to use Liger loss. It requires liger-kernel to be installed."}, + ) + base_model_attribute_name: str = field( + default="model", + metadata={ + "help": "Name of the attribute in the model that contains the base model. This is used to get the base " + "model from the model when the model does not have a `get_decoder` method in the case when " + "`use_liger_loss` is `True`." + }, + ) diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..fa01d1ca06c58fbf4d59eaf385bfcb640d181dcf --- /dev/null +++ b/trl/trainer/kto_trainer.py @@ -0,0 +1,1716 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import os +import random +import textwrap +import warnings +from collections import defaultdict +from contextlib import contextmanager, nullcontext +from operator import itemgetter +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +from accelerate import PartialState +from accelerate.utils import tqdm +from datasets import Dataset, concatenate_datasets +from torch import autocast +from torch.utils.data import DataLoader, SequentialSampler +from transformers import ( + AutoModelForCausalLM, + BaseImageProcessor, + DataCollator, + FeatureExtractionMixin, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + Trainer, + TrainerCallback, + TrainingArguments, + is_comet_available, + is_wandb_available, +) +from transformers.trainer_utils import EvalLoopOutput, has_length +from transformers.utils import is_peft_available + +from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset +from ..import_utils import is_liger_kernel_available +from ..models import create_reference_model, prepare_deepspeed +from .kto_config import KTOConfig +from .utils import ( + DPODataCollatorWithPadding, + disable_dropout_in_model, + generate_model_card, + get_comet_experiment_url, + log_table_to_comet_experiment, + pad_to_length, + peft_module_casting_to_bf16, + selective_log_softmax, +) + + +if is_liger_kernel_available(): + from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss + +if is_peft_available(): + from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + +if is_wandb_available(): + import wandb + + +if TYPE_CHECKING: + from transformers import PreTrainedModel, PreTrainedTokenizer + +RUNNING_NAME = "running.pt" + + +def _get_kl_dataset(batch: dict[str, list[Any]]) -> dict[str, list[Any]]: + """ + Creates mismatched pairs of prompts and completions for the KL dataset by adding a +1 offset to the order of completions. + For best results, the mismatched outputs y' used to estimate the KL term for a batch should be the same set as the matched + outputs y used to estimate the rewards in that batch, just paired with different x. + """ + batch["answer_input_ids"] = [batch["answer_input_ids"][-1]] + batch["answer_input_ids"][:-1] + batch["answer_attention_mask"] = [batch["answer_attention_mask"][-1]] + batch["answer_attention_mask"][:-1] + return batch + + +def _tokenize( + batch: dict[str, list[Any]], + tokenizer: "PreTrainedTokenizer", +) -> dict[str, list[Any]]: + """Tokenize a batch from a KTO specific dataset.""" + prompt_tokenized = tokenizer(batch["prompt"], add_special_tokens=False) + prompt_input_ids = prompt_tokenized["input_ids"] + prompt_attention_mask = prompt_tokenized["attention_mask"] + prompt_and_completion = [prompt + completion for prompt, completion in zip(batch["prompt"], batch["completion"])] + full_tokenized = tokenizer(prompt_and_completion, add_special_tokens=False) + full_input_ids = full_tokenized["input_ids"] + full_attention_mask = full_tokenized["attention_mask"] + + answer_input_ids = [f[len(p) :] for f, p in zip(full_input_ids, prompt_input_ids)] + answer_attention_mask = [f[len(p) :] for f, p in zip(full_attention_mask, prompt_attention_mask)] + + # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]` + full_concat_input_ids = [np.concatenate([p, a]) for p, a in zip(prompt_input_ids, answer_input_ids)] + # Prepare input tokens for token by token comparison + full_input_ids = [np.array(f) for f in full_input_ids] + for full, concat in zip(full_input_ids, full_concat_input_ids): + if len(full) != len(concat): + raise ValueError( + "The elements in 'full_input_ids' and 'full_concat_input_ids' must have the same pairwise length." + ) + + # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens + # can be merged together when tokenizing prompt+answer. This could result + # on the last token from the prompt being different when tokenized on its own + # vs when done as prompt+answer. + response_token_ids_start_idx = [len(p) for p in prompt_input_ids] + + # If tokenized prompt is different than both prompt+answer, then it means the + # last token has changed due to merging. + for idx, (p, f, r) in enumerate(zip(prompt_input_ids, full_input_ids, response_token_ids_start_idx)): + if not np.array_equal(p, f[:r]): + response_token_ids_start_idx[idx] -= 1 + + prompt_input_ids = [f[:r] for f, r in zip(full_input_ids, response_token_ids_start_idx)] + prompt_attention_mask = [f[:r] for f, r in zip(full_attention_mask, response_token_ids_start_idx)] + + for p, m in zip(prompt_input_ids, prompt_attention_mask): + if len(p) != len(m): + raise ValueError("Prompt input ids and attention mask should have the same length.") + + answer_input_ids = [f[r:] for f, r in zip(full_input_ids, response_token_ids_start_idx)] + answer_attention_mask = [f[r:] for f, r in zip(full_attention_mask, response_token_ids_start_idx)] + + output = dict( + prompt_input_ids=prompt_input_ids, + prompt_attention_mask=prompt_attention_mask, + answer_input_ids=answer_input_ids, + answer_attention_mask=answer_attention_mask, + ) + + return output + + +def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = None, **kwargs) -> dict: + """Process tokens of a KTO specific dataset. + + At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation + in case the prompt + completion responses is/are too long. First + we truncate the prompt; if we're still too long, we truncate the completion. + + We also create the labels for the completion responses, which are of length equal to + the sum of the length of the prompt and the completion response, with + label_pad_token_id for the prompt tokens. + """ + prompt = example["prompt"] + completion = example["completion"] + + batch = { + f"{kwargs['prefix']}prompt": prompt, + f"{kwargs['prefix']}completion": completion, + f"{kwargs['prefix']}label": example["label"], + } + + if not kwargs["is_encoder_decoder"]: + # Check issues below for more details + # 1. https://github.com/huggingface/trl/issues/907 + # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 + # 3. https://github.com/LianjiaTech/BELLE/issues/337 + + if not isinstance(prompt, str): + raise ValueError(f"prompt should be an str but got {type(prompt)}") + + if not isinstance(completion, str): + raise ValueError(f"completion should be an str but got {type(completion)}") + + # keys of format prompt_* refers to just the prompt and answer_* refers to just the answer + all_tokens = { + "prompt_input_ids": example["prompt_input_ids"], + "prompt_attention_mask": example["prompt_attention_mask"], + "answer_input_ids": example["answer_input_ids"], + "answer_attention_mask": example["answer_attention_mask"], + } + + # calculate max length by checking if BOS/EOS is already there + max_length = kwargs["max_length"] + bos_token_id = kwargs["tokenizer"].bos_token_id + eos_token_id = kwargs["tokenizer"].eos_token_id + if len(all_tokens["prompt_input_ids"]) > 0 and bos_token_id != all_tokens["prompt_input_ids"][0]: + max_length -= 1 + if len(all_tokens["answer_input_ids"]) > 0 and eos_token_id != all_tokens["answer_input_ids"][-1]: + max_length -= 1 + + # if combined sequence is too long (> max_length - 1 for BOS token - 1 for EOS), truncate the prompt + if len(all_tokens["prompt_input_ids"]) + len(all_tokens["answer_input_ids"]) > max_length: + for k in ["prompt_input_ids", "prompt_attention_mask"]: + if kwargs["truncation_mode"] == "keep_start": + all_tokens[k] = all_tokens[k][: kwargs["max_prompt_length"]] + elif kwargs["truncation_mode"] == "keep_end": + all_tokens[k] = all_tokens[k][-kwargs["max_prompt_length"] :] + else: + raise ValueError(f"Unknown truncation mode: {kwargs['truncation_mode']}") + + # if that's still too long, truncate the response + if len(all_tokens["prompt_input_ids"]) + len(all_tokens["answer_input_ids"]) > max_length: + for k in ["answer_input_ids", "answer_attention_mask"]: + all_tokens[k] = all_tokens[k][: max_length - kwargs["max_prompt_length"]] + + # all input_ids and attention mask as is. We then check if we need to add BOS/EOS tokens + batch[f"{kwargs['prefix']}prompt_input_ids"] = all_tokens["prompt_input_ids"] + batch[f"{kwargs['prefix']}prompt_attention_mask"] = all_tokens["prompt_attention_mask"] + batch[f"{kwargs['prefix']}completion_input_ids"] = ( + all_tokens["prompt_input_ids"] + all_tokens["answer_input_ids"] + ) + batch[f"{kwargs['prefix']}completion_attention_mask"] = ( + all_tokens["prompt_attention_mask"] + all_tokens["answer_attention_mask"] + ) + + # add BOS, which affects both prompt and the full completion + if bos_token_id is not None: + if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]: + batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[ + f"{kwargs['prefix']}prompt_input_ids" + ] + batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[ + f"{kwargs['prefix']}prompt_attention_mask" + ] + batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[ + f"{kwargs['prefix']}completion_input_ids" + ] + batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[ + f"{kwargs['prefix']}completion_attention_mask" + ] + # add EOS, which affects only the full completion + if len(all_tokens["answer_input_ids"]) == 0 or eos_token_id != all_tokens["answer_input_ids"][-1]: + batch[f"{kwargs['prefix']}completion_input_ids"] = batch[f"{kwargs['prefix']}completion_input_ids"] + [ + eos_token_id + ] + batch[f"{kwargs['prefix']}completion_attention_mask"] = batch[ + f"{kwargs['prefix']}completion_attention_mask" + ] + [1] + + batch[f"{kwargs['prefix']}completion_labels"] = batch[f"{kwargs['prefix']}completion_input_ids"][:] + batch[f"{kwargs['prefix']}completion_labels"][: len(batch[f"{kwargs['prefix']}prompt_input_ids"])] = [ + kwargs["label_pad_token_id"] + ] * len(batch[f"{kwargs['prefix']}prompt_input_ids"]) + else: + completion_tokens = kwargs["tokenizer"]( + completion, truncation=True, max_length=kwargs["max_completion_length"], add_special_tokens=True + ) + prompt_tokens = kwargs["tokenizer"]( + prompt, truncation=True, max_length=kwargs["max_prompt_length"], add_special_tokens=True + ) + + batch[f"{kwargs['prefix']}prompt_input_ids"] = prompt_tokens["input_ids"] + batch[f"{kwargs['prefix']}prompt_attention_mask"] = prompt_tokens["attention_mask"] + + batch[f"{kwargs['prefix']}completion_labels"] = completion_tokens["input_ids"] + batch[f"{kwargs['prefix']}completion_attention_mask"] = completion_tokens["attention_mask"] + if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"): + batch[f"{kwargs['prefix']}completion_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( + labels=torch.tensor(batch["completion_labels"]) + ) + + return batch + + +class KTOTrainer(Trainer): + r""" + Initialize KTOTrainer. + + Args: + model (`transformers.PreTrainedModel`): + The model to train, preferably an `AutoModelForSequenceClassification`. + ref_model (`PreTrainedModelWrapper`): + Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no + reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized. + args (`KTOConfig`): + The arguments to use for training. + train_dataset (`datasets.Dataset`): + The dataset to use for training. + eval_dataset (`datasets.Dataset`): + The dataset to use for evaluation. + processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*): + Processing class used to process the data. If provided, will be used to automatically process the inputs + for the model, and it will be saved along the model to make it easier to rerun an interrupted training or + reuse the fine-tuned model. + data_collator (`transformers.DataCollator`, *optional*, defaults to `None`): + The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used + which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences. + model_init (`Callable[[], transformers.PreTrainedModel]`): + The model initializer to use for training. If None is specified, the default model initializer will be used. + callbacks (`list[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + peft_config (`dict`, defaults to `None`): + The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function to use to compute the metrics. Must take a `EvalPrediction` and return + a dictionary string to metric values. + model_adapter_name (`str`, defaults to `None`): + Name of the train target PEFT adapter, when using LoRA with multiple adapters. + ref_adapter_name (`str`, defaults to `None`): + Name of the reference PEFT adapter, when using LoRA with multiple adapters. + """ + + _tag_names = ["trl", "kto"] + + def __init__( + self, + model: Union[PreTrainedModel, nn.Module, str] = None, + ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + args: KTOConfig = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, + processing_class: Optional[ + Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] + ] = None, + data_collator: Optional[DataCollator] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + callbacks: Optional[list[TrainerCallback]] = None, + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + peft_config: Optional[dict] = None, + compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None, + model_adapter_name: Optional[str] = None, + ref_adapter_name: Optional[str] = None, + ): + if type(args) is TrainingArguments: + raise ValueError("Please use `KTOConfig` instead TrainingArguments.") + + if not isinstance(model, str) and ref_model is model: + raise ValueError( + "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " + "same as `model`, you must mass a copy of it, or `None` if you use peft." + ) + + if args.model_init_kwargs is None: + model_init_kwargs = {} + elif not isinstance(model, str): + raise ValueError("You passed model_kwargs to the KTOTrainer. But your model is already instantiated.") + else: + model_init_kwargs = args.model_init_kwargs + torch_dtype = model_init_kwargs.get("torch_dtype") + if torch_dtype is not None: + # Convert to `torch.dtype` if an str is passed + if isinstance(torch_dtype, str) and torch_dtype != "auto": + torch_dtype = getattr(torch, torch_dtype) + if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}." + ) + model_init_kwargs["torch_dtype"] = torch_dtype + + if args.ref_model_init_kwargs is None: + ref_model_init_kwargs = {} + elif not isinstance(ref_model, str): + raise ValueError( + "You passed ref_model_kwargs to the KTOTrainer. But your ref_model is already instantiated." + ) + else: + ref_model_init_kwargs = args.ref_model_init_kwargs + torch_dtype = ref_model_init_kwargs.get("torch_dtype") + if torch_dtype is not None: + # Convert to `torch.dtype` if an str is passed + if isinstance(torch_dtype, str) and torch_dtype != "auto": + torch_dtype = getattr(torch, torch_dtype) + if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}." + ) + ref_model_init_kwargs["torch_dtype"] = torch_dtype + + if isinstance(model, str): + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + + if isinstance(ref_model, str): + ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs) + + # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` + # has been called in order to properly call autocast if needed. + self._peft_has_been_casted_to_bf16 = False + + if not is_peft_available() and peft_config is not None: + raise ValueError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + # if model is a peft model and we have a peft_config, we merge and unload it first + if isinstance(model, PeftModel): + model = model.merge_and_unload() + + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): + _support_gc_kwargs = hasattr( + args, "gradient_checkpointing_kwargs" + ) and "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + + prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} + + if _support_gc_kwargs: + prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) + elif args.gradient_checkpointing: + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + # get peft model with the given config + model = get_peft_model(model, peft_config) + if args.bf16 and getattr(model, "is_loaded_in_4bit", False): + peft_module_casting_to_bf16(model) + # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager + self._peft_has_been_casted_to_bf16 = True + + # For models that use gradient_checkpointing, we need to attach a hook that enables input + # to explicitly have `requires_grad=True`, otherwise training will either silently + # fail or completely fail. + elif args.gradient_checkpointing: + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if args.generate_during_eval and not (is_wandb_available() or is_comet_available()): + raise ValueError( + "`generate_during_eval=True` requires Weights and Biases or Comet to be installed." + " Please install `wandb` or `comet-ml` to resolve." + ) + + if model is not None: + self.is_encoder_decoder = model.config.is_encoder_decoder + elif args.is_encoder_decoder is None: + raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.") + else: + self.is_encoder_decoder = args.is_encoder_decoder + + self.is_peft_model = is_peft_available() and isinstance(model, PeftModel) + self.model_adapter_name = model_adapter_name + self.ref_adapter_name = ref_adapter_name + + if ref_model: + self.ref_model = ref_model + elif self.is_peft_model or args.precompute_ref_log_probs: + # The `model` with adapters turned off will be used as the reference model + self.ref_model = None + else: + self.ref_model = create_reference_model(model) + + if processing_class is None: + raise ValueError( + "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding" + ) + if args.max_length is None: + warnings.warn( + "When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init" + " it will be set to `512` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_length = 512 + if args.max_length is not None: + max_length = args.max_length + + if args.max_prompt_length is None: + warnings.warn( + "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init" + " it will be set to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_prompt_length = 128 + if args.max_prompt_length is not None: + max_prompt_length = args.max_prompt_length + + max_completion_length = None + if args.max_completion_length is None and self.is_encoder_decoder: + warnings.warn( + "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the KTOTrainer's init" + " it will be set to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_completion_length = 128 + if args.max_completion_length is not None and self.is_encoder_decoder: + max_completion_length = args.max_completion_length + + if data_collator is None: + data_collator = DPODataCollatorWithPadding( + pad_token_id=processing_class.pad_token_id, + label_pad_token_id=args.label_pad_token_id, + is_encoder_decoder=self.is_encoder_decoder, + ) + + if args.remove_unused_columns: + args.remove_unused_columns = False + # warn users + warnings.warn( + "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig" + " we have set it for you, but you should do it yourself in the future.", + UserWarning, + ) + + self.use_dpo_data_collator = True + else: + self.use_dpo_data_collator = False + + # Disable dropout in the model and reference model + if args.disable_dropout: + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + self.loss_type = args.loss_type + self.max_length = max_length + self.generate_during_eval = args.generate_during_eval + self.label_pad_token_id = args.label_pad_token_id + self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id + self.max_prompt_length = max_prompt_length + self.truncation_mode = args.truncation_mode + self.max_completion_length = max_completion_length + self.processing_class = processing_class + self.precompute_ref_log_probs = args.precompute_ref_log_probs + + # Not all losses require a KL calculation + self.calculate_KL = True + if self.loss_type in ["apo_zero_unpaired"]: + self.calculate_KL = False + + # Since ref_logs are precomputed on the first call to get_train/eval_dataloader + # keep track of first called to avoid computation of future calls + self._precomputed_train_ref_log_probs = False + self._precomputed_eval_ref_log_probs = False + + # metric + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + # KTO parameter + self.beta = args.beta + self.desirable_weight = args.desirable_weight + self.undesirable_weight = args.undesirable_weight + self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) + self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0) + if self.aux_loss_enabled and self.aux_loss_coef == 0.0: + warnings.warn( + "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to " + "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value " + "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary " + "loss.", + UserWarning, + ) + + # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the + # input tensor associated with the key "input_ids". However, in KTO, the sampled data does not include the + # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result, + # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point + # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's + # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been + # issued. + model.warnings_issued["estimate_tokens"] = True + + # Compute that only on the main process for faster data processing. + # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().main_process_first(): + # Extract the prompt if needed + train_dataset = train_dataset.map( + maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset" + ) + # Unpair the dataset if needed + train_dataset = maybe_unpair_preference_dataset( + train_dataset, args.dataset_num_proc, desc="Unpairing train dataset" + ) + # Apply the chat template if needed + train_dataset = train_dataset.map( + maybe_apply_chat_template, + fn_kwargs={"tokenizer": processing_class}, + num_proc=args.dataset_num_proc, + desc="Applying chat template to train dataset", + ) + if eval_dataset is not None: + eval_dataset = eval_dataset.map( + maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset" + ) + eval_dataset = maybe_unpair_preference_dataset( + eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset" + ) + eval_dataset = eval_dataset.map( + maybe_apply_chat_template, + fn_kwargs={"tokenizer": processing_class}, + num_proc=args.dataset_num_proc, + desc="Applying chat template to eval dataset", + ) + + # Tokenize and prepare the training datasets + train_dataset = train_dataset.map( + _tokenize, + batched=True, + fn_kwargs={"tokenizer": self.processing_class}, + num_proc=args.dataset_num_proc, + desc="Tokenizing train dataset", + ) + + fn_kwargs = { + "prefix": "", + "is_encoder_decoder": self.is_encoder_decoder, + "tokenizer": self.processing_class, + "max_length": self.max_length, + "truncation_mode": self.truncation_mode, + "label_pad_token_id": self.label_pad_token_id, + "max_prompt_length": self.max_prompt_length, + "max_completion_length": self.max_completion_length, + } + + train_dataset = train_dataset.map( + _process_tokens, + fn_kwargs=fn_kwargs, + num_proc=args.dataset_num_proc, + desc="Processing tokenized train dataset", + ) + + # Tokenize and prepare the eval datasets + if eval_dataset is not None: + eval_dataset = eval_dataset.map( + _tokenize, + fn_kwargs={"tokenizer": self.processing_class}, + batched=True, + num_proc=args.dataset_num_proc, + desc="Tokenizing eval dataset", + ) + + eval_dataset = eval_dataset.map( + _process_tokens, + fn_kwargs=fn_kwargs, + num_proc=args.dataset_num_proc, + desc="Processing tokenized eval dataset", + ) + + # Get KL datasets if needed + if self.calculate_KL: + if args.per_device_train_batch_size <= 1: + raise ValueError( + "Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward." + ) + + # create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size + # i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n) + train_kl_dataset = train_dataset.map( + _get_kl_dataset, + batched=True, + batch_size=args.per_device_train_batch_size, + num_proc=args.dataset_num_proc, + desc="Extracting KL train dataset", + ) + + fn_kwargs["prefix"] = "KL_" + train_kl_dataset = train_kl_dataset.map( + _process_tokens, + fn_kwargs=fn_kwargs, + num_proc=args.dataset_num_proc, + remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names], + desc="Processing tokenized train KL dataset", + ) + + # merge the datasets + train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1) + + if eval_dataset is not None: + # Get KL dataset + eval_kl_dataset = eval_dataset.map( + _get_kl_dataset, + batched=True, + batch_size=args.per_device_train_batch_size, + num_proc=args.dataset_num_proc, + desc="Extracting eval KL dataset", + ) + + eval_kl_dataset = eval_kl_dataset.map( + _process_tokens, + fn_kwargs=fn_kwargs, + num_proc=args.dataset_num_proc, + remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names], + desc="Processing tokenized eval KL dataset", + ) + + # merge the datasets + eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1) + + # calculate dataset desirability balance + num_desirable = max(sum(train_dataset["label"]), 1) + num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1) # "label" is binary + + if num_desirable != num_undesirable: + # The lower and upper bounds come from Eq. (8) of https://huggingface.co/papers/2402.01306 + des_weight_lower_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1, 2) + des_weight_upper_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1.33, 2) + und_weight_lower_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1.33, 2) + und_weight_upper_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1, 2) + + des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound + und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound + + if not (des_weight_in_range or und_weight_in_range): + warnings.warn( + "You have different amounts of desirable/positive and undesirable/negative examples but the " + "weights on the desirable and undesirable losses don't seem to be in an ideal range. Based " + f"on your data, we recommend EITHER " + f"desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or " + f"undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). " + "See the documentation on how to optimally set these weights.", + UserWarning, + ) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. + self.model_accepts_loss_kwargs = False + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + if not hasattr(self, "accelerator"): + raise AttributeError( + "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." + ) + + # Deepspeed Zero-3 does not support precompute_ref_log_probs + if self.is_deepspeed_enabled: + if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs: + raise ValueError( + "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`." + ) + + if self.ref_model is None: + if not (self.is_peft_model or self.precompute_ref_log_probs): + raise ValueError( + "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`" + ) + else: + if self.is_deepspeed_enabled: + self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + # Import Liger loss if enabled + if self.args.use_liger_loss: + if not is_liger_kernel_available(): + raise ImportError( + "You set `use_liger_loss=True` but the liger kernel is not available. " + "Please install liger-kernel first: `pip install liger-kernel`" + ) + if self.loss_type in ["apo_zero_unpaired"]: + raise ValueError( + "You cannot set `loss_type='apo_zero_unpaired'` with liger-kernel." + "Only KTO loss is supported with liger-kernel." + ) + if self.precompute_ref_log_probs: + raise ValueError( + "You cannot use `precompute_ref_log_probs=True` with liger kernel. Please set " + "`precompute_ref_log_probs=False`." + ) + if self.is_peft_model or self.ref_adapter_name is not None: + raise ValueError( + "You cannot use `use_liger_loss=True` with Peft models. Please set `use_liger_loss=False`." + ) + self.kto_loss_fn = LigerFusedLinearKTOLoss( + ignore_index=self.label_pad_token_id, beta=self.beta, use_ref_model=(self.ref_model is not None) + ) + + @contextmanager + def null_ref_context(self): + """Context manager for handling null reference model (that is, peft adapter manipulation).""" + with ( + self.accelerator.unwrap_model(self.model).disable_adapter() + if self.is_peft_model and not self.ref_adapter_name + else nullcontext() + ): + if self.ref_adapter_name: + self.model.set_adapter(self.ref_adapter_name) + yield + if self.ref_adapter_name: + self.model.set_adapter(self.model_adapter_name or "default") + + def get_train_dataloader(self) -> DataLoader: + """ + Returns the training [`~torch.utils.data.DataLoader`]. + + Subclass of transformers.src.transformers.trainer.get_train_dataloader to precompute `ref_log_probs`. + """ + + if self.precompute_ref_log_probs and not self._precomputed_train_ref_log_probs: + dataloader_params = { + "batch_size": self.args.per_device_train_batch_size, + "collate_fn": self.data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "shuffle": False, + } + + # prepare dataloader + data_loader = self.accelerator.prepare(DataLoader(self.train_dataset, **dataloader_params)) + reference_completion_logps = [] + reference_KL_logps = [] + + for padded_batch in tqdm(iterable=data_loader, desc="Train dataset reference log probs"): + reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch) + + reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp) + reference_completion_logps.append(reference_completion_logp.cpu()) + + if self.calculate_KL: + reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp) + reference_KL_logps.append(reference_KL_logp.cpu()) + + self.train_dataset = self.train_dataset.add_column( + name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy() + ) + + if self.calculate_KL: + self.train_dataset = self.train_dataset.add_column( + name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy() + ) + + self._precomputed_train_ref_log_probs = True + + return super().get_train_dataloader() + + def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: + """ + Returns the evaluation [`~torch.utils.data.DataLoader`]. + + Subclass of transformers.src.transformers.trainer.get_eval_dataloader to precompute `ref_log_probs`. + + Args: + eval_dataset (`torch.utils.data.Dataset`, *optional*): + If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted + by the `model.forward()` method are automatically removed. It must implement `__len__`. + """ + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + + if self.precompute_ref_log_probs and not self._precomputed_eval_ref_log_probs: + dataloader_params = { + "batch_size": self.args.per_device_eval_batch_size, + "collate_fn": self.data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "shuffle": False, + } + + # prepare dataloader + data_loader = self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params)) + + reference_completion_logps = [] + reference_KL_logps = [] + + for padded_batch in tqdm(iterable=data_loader, desc="Eval dataset reference log probs"): + reference_completion_logp, reference_KL_logp = self.compute_reference_log_probs(padded_batch) + + reference_completion_logp = self.accelerator.gather_for_metrics(reference_completion_logp) + reference_completion_logps.append(reference_completion_logp.cpu()) + + if self.calculate_KL: + reference_KL_logp = self.accelerator.gather_for_metrics(reference_KL_logp) + reference_KL_logps.append(reference_KL_logp.cpu()) + + eval_dataset = eval_dataset.add_column( + name="reference_logps", column=torch.cat(reference_completion_logps).float().numpy() + ) + if self.calculate_KL: + eval_dataset = eval_dataset.add_column( + name="reference_KL_logps", column=torch.cat(reference_KL_logps).float().numpy() + ) + + # Save calculated reference_chosen_logps and reference_rejected_logps to the eval_dataset for subsequent runs + if self.eval_dataset is not None: + self.eval_dataset = eval_dataset + self._precomputed_eval_ref_log_probs = True + + return super().get_eval_dataloader(eval_dataset=eval_dataset) + + def compute_reference_log_probs(self, padded_batch: dict) -> dict: + """Computes log probabilities of the reference model for a single padded batch of a KTO specific dataset.""" + with torch.no_grad(): + if self.ref_model is None: + with self.null_ref_context(): + if self.is_encoder_decoder: + completion_logits = self.model( + padded_batch["prompt_input_ids"], + attention_mask=padded_batch["prompt_attention_mask"], + decoder_input_ids=padded_batch.get("completion_decoder_input_ids"), + labels=padded_batch["completion_labels"], + ).logits + + if self.calculate_KL: + KL_logits = self.model( + padded_batch["KL_prompt_input_ids"], + attention_mask=padded_batch["KL_prompt_attention_mask"], + decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"), + labels=padded_batch["KL_completion_labels"], + ).logits + else: + completion_logits = self.model( + padded_batch["completion_input_ids"], + attention_mask=padded_batch["completion_attention_mask"], + ).logits + + if self.calculate_KL: + KL_logits = self.model( + padded_batch["KL_completion_input_ids"], + attention_mask=padded_batch["KL_completion_attention_mask"], + ).logits + else: + if self.is_encoder_decoder: + completion_logits = self.ref_model( + padded_batch["prompt_input_ids"], + attention_mask=padded_batch["prompt_attention_mask"], + decoder_input_ids=padded_batch.get("completion_decoder_input_ids"), + labels=padded_batch["completion_labels"], + ).logits + + if self.calculate_KL: + KL_logits = self.ref_model( + padded_batch["KL_prompt_input_ids"], + attention_mask=padded_batch["KL_prompt_attention_mask"], + decoder_input_ids=padded_batch.get("KL_completion_decoder_input_ids"), + labels=padded_batch["KL_completion_labels"], + ).logits + else: + completion_logits = self.ref_model( + padded_batch["completion_input_ids"], attention_mask=padded_batch["completion_attention_mask"] + ).logits + + if self.calculate_KL: + KL_logits = self.ref_model( + padded_batch["KL_completion_input_ids"], + attention_mask=padded_batch["KL_completion_attention_mask"], + ).logits + + completion_logps = self.get_batch_logps( + completion_logits, + padded_batch["completion_labels"], + average_log_prob=False, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) + + if self.calculate_KL: + KL_logps = self.get_batch_logps( + KL_logits, + padded_batch["KL_completion_labels"], + average_log_prob=False, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) + else: + KL_logps = None + + return completion_logps, KL_logps + + @staticmethod + def get_batch_logps( + logits: torch.FloatTensor, + labels: torch.LongTensor, + average_log_prob: bool = False, + label_pad_token_id: int = -100, + is_encoder_decoder: bool = False, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length) + average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. + """ + if logits.shape[:-1] != labels.shape: + raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") + + if not is_encoder_decoder: + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + else: + # Fixes end-dec RuntimeError + labels = labels.clone() + + loss_mask = labels != label_pad_token_id + + # dummy token; we'll ignore the losses on these tokens later + labels[labels == label_pad_token_id] = 0 + + per_token_logps = selective_log_softmax(logits, labels) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def forward( + self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]] + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + KL_logps = self._compute_kl_logps(model, batch) + + model_kwargs = ( + { + "labels": batch["completion_labels"], + "decoder_input_ids": batch.get("completion_decoder_input_ids"), + } + if self.is_encoder_decoder + else {} + ) + if self.aux_loss_enabled: + model_kwargs["output_router_logits"] = True + + outputs = model( + batch["completion_input_ids"], + attention_mask=batch["completion_attention_mask"], + **model_kwargs, + ) + completion_logits = outputs.logits + + completion_logps = self.get_batch_logps( + completion_logits, + batch["completion_labels"], + average_log_prob=False, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) + + if completion_logps.shape[0] != len(batch["label"]): + raise ValueError( + "There is a mismatch between the number of examples in this batch and the number of " + "examples for which an output sequence was predicted." + ) + + chosen_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is True] + rejected_idx = [i for i in range(completion_logps.shape[0]) if batch["label"][i] is False] + + chosen_logps = completion_logps[chosen_idx, ...] + rejected_logps = completion_logps[rejected_idx, ...] + + chosen_logits = completion_logits[chosen_idx, ...] + rejected_logits = completion_logits[rejected_idx, ...] + + if self.aux_loss_enabled: + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps, outputs.aux_loss) + else: + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, KL_logps) + + def kto_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + policy_KL_logps: torch.FloatTensor, + reference_chosen_logps: torch.FloatTensor, + reference_rejected_logps: torch.FloatTensor, + reference_KL_logps: torch.FloatTensor, + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Compute the KTO loss for a batch of policy and reference model log probabilities. + + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (num(chosen) in batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (num(rejected) in batch_size,) + policy_KL_logps: Log probabilities of the policy model for the KL responses. Shape: (batch_size,) + reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (num(chosen) in batch_size,) + reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (num(rejected) in batch_size,) + reference_KL_logps: Log probabilities of the reference model for the KL responses. Shape: (batch_size,) + + Returns: + A tuple of four tensors: (losses, chosen_rewards, rejected_rewards, KL). + The losses tensor contains the KTO loss for each example in the batch. + The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. + The KL tensor contains the detached KL divergence estimate between the policy and reference models. + """ + if self.calculate_KL: + kl = (policy_KL_logps - reference_KL_logps).mean().detach() + kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0) + else: + kl = torch.zeros(1).to(policy_chosen_logps.device) + + # Chosen losses + if policy_chosen_logps.shape[0] != 0 or reference_chosen_logps.shape[0] != 0: + chosen_logratios = policy_chosen_logps - reference_chosen_logps + + if self.loss_type == "kto": + # Eqn (7) of the KTO paper (https://huggingface.co/papers/2402.01306) + chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - kl)) + elif self.loss_type == "apo_zero_unpaired": + # Unpaired variant of Eqn (7) of the APO paper (https://huggingface.co/papers/2408.06266) + # Use this loss when you believe the chosen outputs are better than your model's default output + chosen_losses = 1 - F.sigmoid(self.beta * chosen_logratios) + + chosen_rewards = self.beta * chosen_logratios.detach() + + else: + # lists can't be empty -- if they are, then accelerate.gather will hang + chosen_losses = torch.Tensor([]).to(self.accelerator.device) + chosen_rewards = torch.Tensor([]).to(self.accelerator.device) + + # Rejected losses + if policy_rejected_logps.shape[0] != 0 or reference_rejected_logps.shape[0] != 0: + rejected_logratios = policy_rejected_logps - reference_rejected_logps + + if self.loss_type == "kto": + rejected_losses = 1 - F.sigmoid(self.beta * (kl - rejected_logratios)) + elif self.loss_type == "apo_zero_unpaired": + rejected_losses = F.sigmoid(self.beta * rejected_logratios) + + rejected_rewards = self.beta * rejected_logratios.detach() + else: + # lists can't be empty -- if they are, then accelerate.gather will hang + rejected_losses = torch.Tensor([]).to(self.accelerator.device) + rejected_rewards = torch.Tensor([]).to(self.accelerator.device) + + losses = torch.cat( + (self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), + 0, + ) + + return losses, chosen_rewards, rejected_rewards, kl + + def _compute_kl_logps(self, model, batch): + """Compute KL log probabilities for a given batch.""" + KL_logps = None + if self.calculate_KL: + if self.is_encoder_decoder: + KL_model_kwargs = { + "input_ids": batch["KL_prompt_input_ids"], + "attention_mask": batch["KL_prompt_attention_mask"], + "labels": batch["KL_completion_labels"], + "decoder_input_ids": batch.get("KL_completion_decoder_input_ids"), + } + else: + KL_model_kwargs = { + "input_ids": batch["KL_completion_input_ids"], + "attention_mask": batch["KL_completion_attention_mask"], + } + + with torch.no_grad(): + KL_logits = model(**KL_model_kwargs).logits + + KL_logps = self.get_batch_logps( + KL_logits, + batch["KL_completion_labels"], + average_log_prob=False, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) + return KL_logps + + def _compute_loss_liger(self, model, batch): + """ + Compute the KTO loss using the Liger-Kernel's LigerFusedLinearKTOLoss. + + Args: + model: The policy model used for generating log probabilities and outputs. It could be an encoder-decoder model or a regular language model. + batch: A dictionary containing the input data and labels for the batch. + + Returns: + A dictionary containing the following keys: + - "loss": The computed KTO loss for the batch. + - "chosen_logits_sum": Sum of the logits for the chosen responses from the policy model. + - "rejected_logits_sum": Sum of the logits for the rejected responses from the policy model. + - "chosen_logps": Log probabilities of the chosen responses from the policy model. + - "rejected_logps": Log probabilities of the rejected responses from the policy model. + - "chosen_rewards": Rewards for the chosen responses. + - "rejected_rewards": Rewards for the rejected responses. + - "kl": The KL divergence between the policy and reference models (detached). + + If auxiliary loss is enabled, the dictionary will also include: + - "aux_loss": The auxiliary loss from the model outputs. + """ + policy_KL_logps = self._compute_kl_logps(model, batch) + reference_KL_logps = self._compute_kl_logps(self.ref_model, batch) + if self.calculate_KL: + kl = (policy_KL_logps - reference_KL_logps).mean().detach() + kl = self.accelerator.gather_for_metrics(kl).mean().clamp(min=0) + else: + kl = torch.zeros(1).to(self.accelerator.device) + + model_kwargs = ( + { + "labels": batch["completion_labels"], + "decoder_input_ids": batch.get("completion_decoder_input_ids"), + } + if self.is_encoder_decoder + else {} + ) + if self.aux_loss_enabled: + model_kwargs["output_router_logits"] = True + + if self.is_encoder_decoder: + # 1. Get encoder outputs + encoder_outputs = model.get_encoder()( + batch["completion_input_ids"], + attention_mask=batch["completion_attention_mask"], + return_dict=True, + **model_kwargs, + ) + # 2. Get decoder outputs + outputs = model.get_decoder()( + input_ids=model_kwargs["decoder_input_ids"], + encoder_hidden_states=encoder_outputs.last_hidden_state, + use_cache=False, + **model_kwargs, + ) + # 1. Get reference encoder outputs + ref_encoder_outputs = self.ref_model.get_encoder()( + batch["completion_input_ids"], + attention_mask=batch["completion_attention_mask"], + return_dict=True, + **model_kwargs, + ) + # 2. Get reference decoder outputs + ref_outputs = self.ref_model.get_decoder()( + input_ids=model_kwargs["decoder_input_ids"], + encoder_hidden_states=ref_encoder_outputs.last_hidden_state, + use_cache=False, + **model_kwargs, + ) + else: + # skip the lm head and get the last hidden state + if hasattr(model, "get_decoder"): + base_model = model.get_decoder() + else: + base_model = getattr(model, self.args.base_model_attribute_name) + outputs = base_model( + batch["completion_input_ids"], + attention_mask=batch["completion_attention_mask"], + use_cache=False, + **model_kwargs, + ) + + # reference model + if hasattr(self.ref_model, "get_decoder"): + ref_base_model = self.ref_model.get_decoder() + else: + ref_base_model = getattr(self.ref_model, self.args.base_model_attribute_name) + ref_outputs = ref_base_model( + batch["completion_input_ids"], + attention_mask=batch["completion_attention_mask"], + use_cache=False, + **model_kwargs, + ) + lm_head = model.get_output_embeddings() + ref_lm_head = self.ref_model.get_output_embeddings() + + ( + loss, + ( + chosen_logps_sum, + rejected_logps_sum, + chosen_logits_sum, + rejected_logits_sum, + chosen_rewards_sum, + rejected_rewards_sum, + ), + ) = self.kto_loss_fn( + _input=outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state, + lin_weight=lm_head.weight, + target=batch["completion_labels"][:, 1:], + bias=lm_head.bias if hasattr(lm_head, "bias") else None, + preference_labels=torch.tensor(batch["label"], dtype=torch.bool).to(self.accelerator.device), + ref_input=ref_outputs.last_hidden_state[:, :-1] + if not self.is_encoder_decoder + else outputs.last_hidden_state, + ref_weight=ref_lm_head.weight, + ref_bias=ref_lm_head.bias if hasattr(lm_head, "bias") else None, + kl=kl, + ) + + output = { + "loss": loss, + "chosen_logits_sum": chosen_logits_sum, + "rejected_logits_sum": rejected_logits_sum, + "chosen_logps_sum": chosen_logps_sum, + "rejected_logps_sum": rejected_logps_sum, + "chosen_rewards_sum": chosen_rewards_sum, + "rejected_rewards_sum": rejected_rewards_sum, + "kl": kl, + } + if self.aux_loss_enabled: + output["aux_loss"] = outputs.aux_loss + + return output + + def get_batch_loss_metrics( + self, + model, + batch: dict[str, Union[list, torch.LongTensor]], + ): + """Compute the KTO loss and other metrics for the given batch of inputs for train or test.""" + metrics = {} + batch = {k: (v.to(self.accelerator.device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} + + labels = torch.tensor(batch["label"]) + num_chosen = labels.sum().to(self.accelerator.device) + num_rejected = (len(labels) - num_chosen).to(self.accelerator.device) + + if self.args.use_liger_loss: + model_output = self._compute_loss_liger(model, batch) + losses = model_output["loss"] + policy_chosen_logits = model_output["chosen_logits_sum"] + policy_rejected_logits = model_output["rejected_logits_sum"] + policy_chosen_logps = model_output["chosen_logps_sum"] + policy_rejected_logps = model_output["rejected_logps_sum"] + chosen_rewards = model_output["chosen_rewards_sum"] + rejected_rewards = model_output["rejected_rewards_sum"] + kl = model_output["kl"] + if self.aux_loss_enabled: + aux_loss = model_output["aux_loss"] + else: + forward_output = self.forward(model, batch) + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + policy_KL_logps, + ) = forward_output[:5] + if self.aux_loss_enabled: + aux_loss = forward_output[5] + + # if reference_logps in batch use them, otherwise use the reference model + if "reference_logps" in batch: + chosen_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is True] + rejected_idx = [i for i in range(batch["reference_logps"].shape[0]) if batch["label"][i] is False] + + reference_chosen_logps = batch["reference_logps"][chosen_idx, ...] + reference_rejected_logps = batch["reference_logps"][rejected_idx, ...] + if self.calculate_KL: + reference_KL_logps = batch["reference_KL_logps"] + else: + reference_KL_logps = None + else: + with torch.no_grad(): + if self.ref_model is None: + with self.null_ref_context(): + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + reference_KL_logps, + ) = self.forward(self.model, batch)[:5] + else: + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + reference_KL_logps, + ) = self.forward(self.ref_model, batch)[:5] + + losses, chosen_rewards, rejected_rewards, kl = self.kto_loss( + policy_chosen_logps, + policy_rejected_logps, + policy_KL_logps, + reference_chosen_logps, + reference_rejected_logps, + reference_KL_logps, + ) + + metrics["kl"] = kl.item() + + all_num_chosen = self.accelerator.gather_for_metrics(num_chosen).sum().item() + all_num_rejected = self.accelerator.gather_for_metrics(num_rejected).sum().item() + + if all_num_chosen > 0: + metrics["rewards/chosen_sum"] = ( + self.accelerator.gather_for_metrics(chosen_rewards.nansum()).nansum().item() + ) + metrics["logps/chosen_sum"] = ( + self.accelerator.gather_for_metrics(policy_chosen_logps.nansum()).nansum().item() + ) + metrics["logits/chosen_sum"] = ( + self.accelerator.gather_for_metrics(policy_chosen_logits.nansum()).nansum().item() + ) + metrics["count/chosen"] = all_num_chosen + + if all_num_rejected > 0: + metrics["rewards/rejected_sum"] = ( + self.accelerator.gather_for_metrics(rejected_rewards.nansum()).nansum().item() + ) + metrics["logps/rejected_sum"] = ( + self.accelerator.gather_for_metrics(policy_rejected_logps.nansum()).nansum().item() + ) + metrics["logits/rejected_sum"] = ( + self.accelerator.gather_for_metrics(policy_rejected_logits.nansum()).nansum().item() + ) + metrics["count/rejected"] = all_num_rejected + + loss = losses.nanmean() + if self.aux_loss_enabled: + loss += self.aux_loss_coef * aux_loss + + return loss, metrics + + def compute_loss( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: dict[str, Union[torch.Tensor, Any]], + return_outputs=False, + num_items_in_batch=None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]: + compute_loss_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with compute_loss_context_manager: + loss, metrics = self.get_batch_loss_metrics(model, inputs) + + # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class: + loss = loss.to(self.args.device) + # force log the metrics + if self.accelerator.is_main_process: + self.store_metrics(metrics, train_eval="train") + + if return_outputs: + return (loss, metrics) + return loss + + def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: + for key, value in metrics.items(): + self._stored_metrics[train_eval][key].append(value) + + def _get_train_sampler(self, dataset: Optional[Dataset] = None) -> Optional[torch.utils.data.Sampler]: + if dataset is None: + dataset = self.train_dataset + if dataset is None or not has_length(dataset): + return None + return SequentialSampler(dataset) + + def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]: + """Generate samples from the model and reference model for the given batch of inputs.""" + + # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with + # the torch amp context manager as some hidden states are silently casted to full precision. + generate_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with generate_context_manager: + policy_output = model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.processing_class.pad_token_id, + ) + + # if reference_output in batch use that otherwise use the reference model + if "reference_output" in batch: + reference_output = batch["reference_output"] + else: + if self.ref_model is None: + with self.null_ref_context(): + reference_output = self.model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.processing_class.pad_token_id, + ) + else: + reference_output = self.ref_model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.processing_class.pad_token_id, + ) + + policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id) + policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) + + reference_output = pad_to_length(reference_output, self.max_length, self.processing_class.pad_token_id) + reference_output_decoded = self.processing_class.batch_decode(reference_output, skip_special_tokens=True) + + return policy_output_decoded, reference_output_decoded + + def prediction_step( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[list[str]] = None, + ): + if ignore_keys is None: + if hasattr(model, "config"): + ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + prediction_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + with torch.no_grad(), prediction_context_manager: + loss, metrics = self.get_batch_loss_metrics(model, inputs) + + # force log the metrics + if self.accelerator.is_main_process: + self.store_metrics(metrics, train_eval="eval") + + if prediction_loss_only: + return (loss.detach(), None, None) + + # logits for the chosen and rejected samples from model + logits_dict = {} + if "logits/chosen_sum" in metrics: + logits_dict["eval_logits/chosen"] = metrics["logits/chosen_sum"] + if "logits/rejected_sum" in metrics: + logits_dict["eval_logits/rejected"] = metrics["logits/rejected_sum"] + logits = [v for k, v in logits_dict.items() if k not in ignore_keys] + logits = torch.tensor(logits, device=self.accelerator.device) + labels = torch.zeros(logits.shape[0], device=self.accelerator.device) + + return (loss.detach(), logits, labels) + + def evaluation_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: Optional[bool] = None, + ignore_keys: Optional[list[str]] = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Overriding built-in evaluation loop to store metrics for each batch. + Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. + """ + + # Sample and save to game log if requested (for one batch to save time) + if self.generate_during_eval: + # Generate random indices within the range of the total number of samples + num_samples = len(dataloader.dataset) + random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) + + # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader + random_batch_dataset = dataloader.dataset.select(random_indices) + random_batch = self.data_collator(random_batch_dataset) + random_batch = self._prepare_inputs(random_batch) + + target_indicies = [i for i in range(len(random_batch["label"])) if random_batch["label"][i] is False] + target_batch = { + "prompt_input_ids": random_batch["prompt_input_ids"][target_indicies], + "prompt_attention_mask": random_batch["prompt_attention_mask"][target_indicies], + "prompt": itemgetter(*target_indicies)(random_batch["prompt"]), + } + policy_output_decoded, ref_output_decoded = self.generate_from_model_and_ref(self.model, target_batch) + + table = pd.DataFrame( + columns=["Prompt", "Policy", "Ref Model"], + data=[ + [prompt, pol[len(prompt) :], ref[len(prompt) :]] + for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded) + ], + ) + if "wandb" in self.args.report_to: + wandb.log({"game_log": wandb.Table(data=table)}) + + if "comet_ml" in self.args.report_to: + log_table_to_comet_experiment( + name="game_log.csv", + table=table, + ) + + # Base evaluation + initial_output = super().evaluation_loop( + dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix + ) + + return initial_output + + def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: + """ + Log `logs` on the various objects watching training, including stored metrics. + + Args: + logs (`dict[str, float]`): + The values to log. + start_time (`float` or `None`, *optional*, defaults to `None`): + Start time of the training. + """ + # logs either has 'loss' or 'eval_loss' + train_eval = "train" if "loss" in logs else "eval" + # train metrics should have no prefix, eval should have 'eval_' + prefix = "eval_" if train_eval == "eval" else "" + # accumulate average metrics from sums and lengths + for split in ["chosen", "rejected"]: + if f"count/{split}" in self._stored_metrics[train_eval]: + count_sum = torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"]).sum().item() + for metric in ["rewards", "logps", "logits"]: + logs[f"{prefix}{metric}/{split}"] = ( + torch.Tensor(self._stored_metrics[train_eval][f"{metric}/{split}_sum"]).sum().item() + / count_sum + ) + # delete obsolete metric + del self._stored_metrics[train_eval][f"{metric}/{split}_sum"] + del self._stored_metrics[train_eval][f"count/{split}"] + # calculate reward margin + if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs: + logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"] + # Add averaged stored metrics to logs + for key, metrics in self._stored_metrics[train_eval].items(): + logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + return super().log(logs, start_time) + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, list[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str` or `None`, *optional*, defaults to `None`): + Name of the model. + dataset_name (`str` or `None`, *optional*, defaults to `None`): + Name of the dataset used for training. + tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. + """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + tags = tags or set() + if isinstance(tags, str): + tags = {tags} + + if hasattr(self.model.config, "unsloth_version"): + tags.add("unsloth") + + tags.update(self._tag_names) + + citation = textwrap.dedent("""\ + @article{ethayarajh2024kto, + title = {{KTO: Model Alignment as Prospect Theoretic Optimization}}, + author = {Kawin Ethayarajh and Winnie Xu and Niklas Muennighoff and Dan Jurafsky and Douwe Kiela}, + year = 2024, + eprint = {arXiv:2402.01306}, + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), + trainer_name="KTO", + trainer_citation=citation, + paper_title="KTO: Model Alignment as Prospect Theoretic Optimization", + paper_id="2402.01306", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/model_config.py b/trl/trainer/model_config.py new file mode 100644 index 0000000000000000000000000000000000000000..935bd73526dbdd8e91e0f18577903ee6319115ec --- /dev/null +++ b/trl/trainer/model_config.py @@ -0,0 +1,179 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class ModelConfig: + """ + Configuration class for the models. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + model_name_or_path (`str` or `None`, *optional*, defaults to `None`): + Model checkpoint for weights initialization. + model_revision (`str`, *optional*, defaults to `"main"`): + Specific model version to use. It can be a branch name, a tag name, or a commit id. + torch_dtype (`Literal["auto", "bfloat16", "float16", "float32"]` or `None`, *optional*, defaults to `None`): + Override the default `torch.dtype` and load the model under this dtype. Possible values are + + - `"bfloat16"`: `torch.bfloat16` + - `"float16"`: `torch.float16` + - `"float32"`: `torch.float32` + - `"auto"`: Automatically derive the dtype from the model's weights. + + trust_remote_code (`bool`, *optional*, defaults to `False`): + Whether to allow for custom models defined on the Hub in their own modeling files. This option should only + be set to `True` for repositories you trust and in which you have read the code, as it will execute code + present on the Hub on your local machine. + attn_implementation (`str` or `None`, *optional*, defaults to `None`): + Which attention implementation to use. You can run `--attn_implementation=flash_attention_2`, in which case + you must install this manually by running `pip install flash-attn --no-build-isolation`. + use_peft (`bool`, *optional*, defaults to `False`): + Whether to use PEFT for training. + lora_r (`int`, *optional*, defaults to `16`): + LoRA R value. + lora_alpha (`int`, *optional*, defaults to `32`): + LoRA alpha. + lora_dropout (`float`, *optional*, defaults to `0.05`): + LoRA dropout. + lora_target_modules (`Union[str, list[str]]` or `None`, *optional*, defaults to `None`): + LoRA target modules. + lora_modules_to_save (`list[str]` or `None`, *optional*, defaults to `None`): + Model layers to unfreeze & train. + lora_task_type (`str`, *optional*, defaults to `"CAUSAL_LM"`): + Task type to pass for LoRA (use `"SEQ_CLS"` for reward modeling). + use_rslora (`bool`, *optional*, defaults to `False`): + Whether to use Rank-Stabilized LoRA, which sets the adapter scaling factor to `lora_alpha/√r`, instead of + the original default value of `lora_alpha/r`. + use_dora (`bool`, *optional*, defaults to `False`): + Enable [Weight-Decomposed Low-Rank Adaptation (DoRA)](https://huggingface.co/papers/2402.09353). This + technique decomposes the updates of the weights into two parts, magnitude and direction. Direction is + handled by normal LoRA, whereas the magnitude is handled by a separate learnable parameter. This can + improve the performance of LoRA, especially at low ranks. Right now, DoRA only supports linear and Conv2D + layers. DoRA introduces a bigger overhead than pure LoRA, so it is recommended to merge weights for + inference. + load_in_8bit (`bool`, *optional*, defaults to `False`): + Whether to use 8 bit precision for the base model. Works only with LoRA. + load_in_4bit (`bool`, *optional*, defaults to `False`): + Whether to use 4 bit precision for the base model. Works only with LoRA. + bnb_4bit_quant_type (`str`, *optional*, defaults to `"nf4"`): + Quantization type (`"fp4"` or `"nf4"`). + use_bnb_nested_quant (`bool`, *optional*, defaults to `False`): + Whether to use nested quantization. + """ + + model_name_or_path: Optional[str] = field( + default=None, + metadata={"help": "Model checkpoint for weights initialization."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "Specific model version to use. It can be a branch name, a tag name, or a commit id."}, + ) + torch_dtype: Optional[str] = field( + default=None, + metadata={ + "help": "Override the default `torch.dtype` and load the model under this dtype.", + "choices": ["auto", "bfloat16", "float16", "float32"], + }, + ) + trust_remote_code: bool = field( + default=False, + metadata={ + "help": "Whether to allow for custom models defined on the Hub in their own modeling files. This option " + "should only be set to `True` for repositories you trust and in which you have read the code, as it will " + "execute code present on the Hub on your local machine." + }, + ) + attn_implementation: Optional[str] = field( + default=None, + metadata={ + "help": "Which attention implementation to use. You can run `--attn_implementation=flash_attention_2`, in " + "which case you must install this manually by running `pip install flash-attn --no-build-isolation`." + }, + ) + use_peft: bool = field( + default=False, + metadata={"help": "Whether to use PEFT for training."}, + ) + lora_r: int = field( + default=16, + metadata={"help": "LoRA R value."}, + ) + lora_alpha: int = field( + default=32, + metadata={"help": "LoRA alpha."}, + ) + lora_dropout: float = field( + default=0.05, + metadata={"help": "LoRA dropout."}, + ) + lora_target_modules: Optional[list[str]] = field( + default=None, + metadata={"help": "LoRA target modules."}, + ) + lora_modules_to_save: Optional[list[str]] = field( + default=None, + metadata={"help": "Model layers to unfreeze & train."}, + ) + lora_task_type: str = field( + default="CAUSAL_LM", + metadata={"help": "Task type to pass for LoRA (use 'SEQ_CLS' for reward modeling)."}, + ) + use_rslora: bool = field( + default=False, + metadata={ + "help": "Whether to use Rank-Stabilized LoRA, which sets the adapter scaling factor to `lora_alpha/√r`, " + "instead of the original default value of `lora_alpha/r`." + }, + ) + use_dora: bool = field( + default=False, + metadata={ + "help": "Enable Weight-Decomposed Low-Rank Adaptation (DoRA). This technique decomposes the updates of " + "the weights into two parts, magnitude and direction. Direction is handled by normal LoRA, whereas the " + "magnitude is handled by a separate learnable parameter. This can improve the performance of LoRA, " + "especially at low ranks. Right now, DoRA only supports linear and Conv2D layers. DoRA introduces a " + "bigger overhead than pure LoRA, so it is recommended to merge weights for inference." + }, + ) + load_in_8bit: bool = field( + default=False, + metadata={"help": "Whether to use 8 bit precision for the base model. Works only with LoRA."}, + ) + load_in_4bit: bool = field( + default=False, + metadata={"help": "Whether to use 4 bit precision for the base model. Works only with LoRA."}, + ) + bnb_4bit_quant_type: str = field( + default="nf4", + metadata={"help": "Quantization type.", "choices": ["fp4", "nf4"]}, + ) + use_bnb_nested_quant: bool = field( + default=False, + metadata={"help": "Whether to use nested quantization."}, + ) + + def __post_init__(self): + if self.load_in_8bit and self.load_in_4bit: + raise ValueError("You can't use 8 bit and 4 bit precision at the same time") + + if hasattr(self.lora_target_modules, "__len__") and len(self.lora_target_modules) == 1: + self.lora_target_modules = self.lora_target_modules[0] diff --git a/trl/trainer/nash_md_config.py b/trl/trainer/nash_md_config.py new file mode 100644 index 0000000000000000000000000000000000000000..07d8152f4fae45c9cf16b1815b3a225a9237c695 --- /dev/null +++ b/trl/trainer/nash_md_config.py @@ -0,0 +1,46 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from trl.trainer.online_dpo_config import OnlineDPOConfig + + +@dataclass +class NashMDConfig(OnlineDPOConfig): + r""" + Configuration class for the [`NashMDTrainer`]. + + Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following: + + Parameters: + mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`): + Logit mixture coefficient for the model and reference model. If a list of floats is provided then the + mixture coefficient is selected for each new epoch and the last coefficient is used for the rest of the + epochs. + """ + + mixture_coef: list[float] = field( + default_factory=lambda: [0.5], + metadata={ + "help": "Logit mixture coefficient for the model and reference model. If a list of floats is provided " + "then the mixture coefficient is selected for each new epoch and the last coefficient is used for the " + "rest of the epochs." + }, + ) + + def __post_init__(self): + super().__post_init__() + if hasattr(self.mixture_coef, "__len__") and len(self.mixture_coef) == 1: + self.mixture_coef = self.mixture_coef[0] diff --git a/trl/trainer/nash_md_trainer.py b/trl/trainer/nash_md_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..07825cb855e89060fc171b11f40c999f276d9972 --- /dev/null +++ b/trl/trainer/nash_md_trainer.py @@ -0,0 +1,545 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import textwrap +from typing import Any, Callable, Optional, Union + +import jinja2 +import torch +import torch.nn as nn +import torch.nn.functional as F +from datasets import Dataset, IterableDataset +from transformers import ( + BaseImageProcessor, + FeatureExtractionMixin, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, + is_wandb_available, +) +from transformers.trainer_utils import EvalPrediction +from transformers.training_args import OptimizerNames +from transformers.utils import is_apex_available, is_peft_available + +from ..data_utils import is_conversational, maybe_apply_chat_template +from ..models.modeling_base import GeometricMixtureWrapper +from ..models.utils import unwrap_model_for_generation +from .judges import BasePairwiseJudge +from .nash_md_config import NashMDConfig +from .online_dpo_trainer import OnlineDPOTrainer +from .utils import ( + SIMPLE_CHAT_TEMPLATE, + empty_cache, + generate_model_card, + get_comet_experiment_url, + get_reward, + selective_log_softmax, + truncate_right, +) + + +if is_apex_available(): + from apex import amp + + +if is_wandb_available(): + import wandb + + +if is_peft_available(): + from peft import PeftModel + + +class NashMDTrainer(OnlineDPOTrainer): + r""" + Initialize NashMDTrainer as a subclass of [`OnlineDPOConfig`]. + + Args: + model (`transformers.PreTrainedModel`): + The model to train, preferably an `AutoModelForCausalLM`. + ref_model (`PreTrainedModelWrapper`): + Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no + reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized. + reward_model (`transformers.PreTrainedModel`): + The reward model to score completions with, preferably an `AutoModelForSequenceClassification`. + judge (`BasePairwiseJudge`): + The judge to use for pairwise comparison of model completions. + args (`NashMDConfig`): + The NashMD config arguments to use for training. + data_collator (`transformers.DataCollator`): + The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used + which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences. + train_dataset (`datasets.Dataset`): + The dataset to use for training. + eval_dataset (`datasets.Dataset`): + The dataset to use for evaluation. + processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*): + Processing class used to process the data. If provided, will be used to automatically process the inputs + for the model, and it will be saved along the model to make it easier to rerun an interrupted training or + reuse the fine-tuned model. + peft_config (`dict`): + The peft config to use for training. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function to use to compute the metrics. Must take a `EvalPrediction` and return + a dictionary string to metric values. + callbacks (`list[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + """ + + _tag_names = ["trl", "nash-md"] + + def __init__( + self, + model: Union[PreTrainedModel, nn.Module] = None, + ref_model: Union[PreTrainedModel, nn.Module] = None, + reward_model: Union[PreTrainedModel, nn.Module, None] = None, + judge: Optional[BasePairwiseJudge] = None, + args: Optional[NashMDConfig] = None, + data_collator: Optional[Callable] = None, + train_dataset: Optional[Union[Dataset, IterableDataset]] = None, + eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, + processing_class: Optional[ + Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] + ] = None, + peft_config: Optional[dict] = None, + compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, + callbacks: Optional[list[TrainerCallback]] = None, + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + ) -> None: + super().__init__( + model=model, + ref_model=ref_model, + reward_model=reward_model, + judge=judge, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + reward_processing_class=processing_class, # for now, NashMDTrainer can't use any reward model + peft_config=peft_config, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + self._mixture_coef = self.args.mixture_coef + + # Overwrite the stats dictionary to include NashMD specific statistics + self.stats = { + # Remove "non_score_reward", "rlhf_reward", "scores_margin" + # Add "mixture_coef" + "loss/kl": [], + "objective/entropy": [], + "loss/score": [], + "rewards/probabilities": [], + "rewards/accuracies": [], + "rewards/margins": [], + "logps/chosen": [], + "logps/rejected": [], + "val/model_contain_eos_token": [], + "val/ref_contain_eos_token": [], + "beta": [], + "mixture_coef": [], + } + if self.reward_model is not None: + self.stats["rewards/chosen"] = [] + self.stats["rewards/rejected"] = [] + + @property + def mixture_coef(self): + if isinstance(self._mixture_coef, list): + epoch = self.state.epoch + return self._mixture_coef[epoch] if epoch < len(self._mixture_coef) else self._mixture_coef[-1] + else: + return self._mixture_coef + + def _generate_completions(self, model, prompts): + # Generate completions from the policy model. + with unwrap_model_for_generation(model, self.accelerator) as unwrapped_policy_for_gen_ctx: + model_output = unwrapped_policy_for_gen_ctx.generate( + input_ids=prompts["input_ids"], + attention_mask=prompts["attention_mask"], + generation_config=self.generation_config, + ) + + # Get the DDP/FSDP unwrapped version of the main model. + # This will be the policy model for GeometricMixtureWrapper (PEFT adapters active if PEFT is used). + policy_model_for_gmw = self.accelerator.unwrap_model(model) + + # Determine the correct reference model for GeometricMixtureWrapper. + # This also needs to be DDP/FSDP unwrapped. + ref_model_for_gmw: torch.nn.Module + if self.ref_model is None: + # No explicit ref_model is provided. + # Use the base of the main `model` if it's a PEFT model. + # policy_model_for_gmw is already DDP-unwrapped. + if is_peft_available() and isinstance(policy_model_for_gmw, PeftModel): + ref_model_for_gmw = policy_model_for_gmw.get_base_model() + else: + # Not a PEFT model (or PEFT not available), or already a base model. + # Use the DDP-unwrapped policy model itself as the reference. + ref_model_for_gmw = policy_model_for_gmw + else: + # An explicit ref_model is provided. Unwrap it for DDP/FSDP. + ref_model_for_gmw = self.accelerator.unwrap_model(self.ref_model) + + # Both models given to GeometricMixtureWrapper (policy_model_for_gmw and ref_model_for_gmw) are DDP-unwrapped. + with torch.no_grad(): # Ensure no_grad context for mixture model generation + mixture_model = GeometricMixtureWrapper( + model=policy_model_for_gmw, + ref_model=ref_model_for_gmw, + generation_config=self.generation_config, + mixture_coef=self.mixture_coef, + device=self.accelerator.device, + ) + + mixture_output = mixture_model.generate( + input_ids=prompts["input_ids"], + attention_mask=prompts["attention_mask"], + generation_config=self.generation_config, + ) + + return model_output, mixture_output + + def _process_completions(self, model_output, mixture_output, prompts): + context_length = prompts["input_ids"].shape[1] + + # Process model completions + model_completion_ids = model_output[:, context_length:] + model_completion_ids, model_completion_mask = truncate_right( + model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id + ) + model_data = { + "input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1), + "attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1), + "raw": prompts["raw"], + } + + # Process reference model completions + mixture_completion_ids = mixture_output[:, context_length:] + mixture_completion_ids, mixture_completion_mask = truncate_right( + mixture_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id + ) + mixture_data = { + "input_ids": torch.cat((prompts["input_ids"], mixture_completion_ids), dim=1), + "attention_mask": torch.cat((prompts["attention_mask"], mixture_completion_mask), dim=1), + "raw": prompts["raw"], + } + + return model_data, mixture_data + + def _compute_rewards(self, model_data, mixture_data, context_length): + with torch.no_grad(): + _, model_scores, _ = get_reward( + self.reward_model, model_data["input_ids"], self.processing_class.pad_token_id, context_length + ) + _, mixture_scores, _ = get_reward( + self.reward_model, mixture_data["input_ids"], self.processing_class.pad_token_id, context_length + ) + + # Apply EOS penalty if needed + if self.args.missing_eos_penalty is not None: + model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1) + mixture_contain_eos = torch.any(mixture_data["input_ids"] == self.processing_class.eos_token_id, dim=-1) + model_scores[~model_contain_eos] -= self.args.missing_eos_penalty + mixture_scores[~mixture_contain_eos] -= self.args.missing_eos_penalty + + return model_scores, mixture_scores + + def _compute_judge(self, model_data, mixture_data, context_length): + prompts = model_data["raw"] + model_data_completions = self.processing_class.batch_decode( + model_data["input_ids"][:, context_length:], skip_special_tokens=True + ) + model_data_completions = [completion.strip() for completion in model_data_completions] + + mixture_data_completions = self.processing_class.batch_decode( + mixture_data["input_ids"][:, context_length:], skip_special_tokens=True + ) + mixture_data_completions = [completion.strip() for completion in mixture_data_completions] + if is_conversational({"prompt": prompts[0]}): + model_data_completions = [ + [{"role": "assistant", "content": completion}] for completion in model_data_completions + ] + environment = jinja2.Environment() + template = environment.from_string(SIMPLE_CHAT_TEMPLATE) + prompts = [template.render(messages=message) for message in prompts] + model_data_completions = [template.render(messages=completion) for completion in model_data_completions] + + mixture_data_completions = [ + [{"role": "assistant", "content": completion}] for completion in mixture_data_completions + ] + mixture_data_completions = [ + template.render(messages=completion) for completion in mixture_data_completions + ] + + probability = self.judge.judge( + prompts, + list(zip(model_data_completions, mixture_data_completions)), + return_scores=True, + ) + return torch.tensor(probability, device=model_data["input_ids"].device) + + def _compute_logprobs(self, model, model_data, context_length): + def compute_logprobs_for_data(m, data): + output = m(data["input_ids"], attention_mask=data["attention_mask"]) + logits = output.logits[:, context_length - 1 : -1] + token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:]) + return token_logprobs + + # Compute logprobs for model completions under the model + model_logprobs_model_data = compute_logprobs_for_data(model, model_data) + + # Compute logprobs of model completions under the reference model + with torch.no_grad(): + if self.ref_model is None: + with model.disable_adapter(): + ref_logprobs_model_data = compute_logprobs_for_data(model, model_data) + else: + ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data) + + # Mask padding tokens + model_padding_mask = model_data["attention_mask"][:, context_length:] == 0 + model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0) + ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0) + + return (model_logprobs_model_data, ref_logprobs_model_data) + + def _compute_losses( + self, + model_logprobs_model_data, + ref_logprobs_model_data, + probability, + ): + # reinforce score where 0.5 is a control variate + score = (probability - 0.5) * model_logprobs_model_data.sum(1) + + # kl divergence via reinforce + with torch.no_grad(): + log_ratio = model_logprobs_model_data - ref_logprobs_model_data + kl_div_log = log_ratio.sum(1) + kl_div_loss = (log_ratio * model_logprobs_model_data).sum(1) + + # final loss + loss = self.beta * kl_div_loss - score + + return loss.mean(), score, kl_div_log + + def _log_statistics( + self, + model_data, + mixture_data, + model_logprobs_model_data, + ref_logprobs_model_data, + probability, + score, + kl_div, + context_length, + model_scores=None, + mixture_scores=None, + ): + # Helper function to gather and compute mean + def gather_mean(tensor): + return self.accelerator.gather_for_metrics(tensor).mean().item() + + # Log score + self.stats["loss/score"].append(gather_mean(score)) + # Log KL divergence + self.stats["loss/kl"].append(gather_mean(kl_div)) + + # Log logprobs + model_logprobs_model_data_sum = model_logprobs_model_data.sum(1) + ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1) + + self.stats["logps/chosen"].append(gather_mean(model_logprobs_model_data_sum)) + self.stats["logps/rejected"].append(gather_mean(ref_logprobs_model_data_sum)) + + # Log rewards + if self.reward_model is not None: + self.stats["rewards/chosen"].append(gather_mean(model_scores)) + self.stats["rewards/rejected"].append(gather_mean(mixture_scores)) + + # Log probabilities + self.stats["rewards/probabilities"].append(gather_mean(probability)) + + # Calculate entropy for model data + entropy_model_data = -model_logprobs_model_data.sum(1) + self.stats["objective/entropy"].append(gather_mean(entropy_model_data)) + + # Calculate margins + margin = model_logprobs_model_data_sum - ref_logprobs_model_data_sum + self.stats["rewards/margins"].append(gather_mean(margin)) + + # Calculate accuracy + accuracy = (margin > 0).float() + self.stats["rewards/accuracies"].append(gather_mean(accuracy)) + + # Log EOS token statistics + model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1) + mixture_eos = (mixture_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1) + self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float())) + self.stats["val/ref_contain_eos_token"].append(gather_mean(mixture_eos.float())) + + # Log beta and mixture coef + self.stats["beta"].append(self.beta) + self.stats["mixture_coef"].append(self.mixture_coef) + + def training_step( + self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None + ) -> torch.Tensor: + model.train() + + # Apply chat template and tokenize the input + batch_size = len(next(iter(inputs.values()))) + prompts = inputs["prompt"] + inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)] + inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs] + inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs] + inputs = self.data_collator(inputs) + + # need the prompt_ only + inputs = self._prepare_inputs(inputs) + context_length = inputs["prompt_input_ids"].shape[1] + prompts = { + "input_ids": inputs["prompt_input_ids"], + "attention_mask": inputs["prompt_attention_mask"], + "raw": prompts, + } + del inputs + + # Sample completions from both the model and the reference model + model_output, mixture_output = self._generate_completions(model, prompts) + + # Process model completions + model_data, mixture_data = self._process_completions(model_output, mixture_output, prompts) + + # Compute rewards + if self.reward_model is not None: + model_scores, mixture_scores = self._compute_rewards(model_data, mixture_data, context_length) + # probability of the model data vs the mixture data + probability = F.sigmoid(model_scores - mixture_scores) + else: + model_scores, mixture_scores = None, None + probability = self._compute_judge(model_data, mixture_data, context_length) + + # Compute logprobs + model_logprobs_model_data, ref_logprobs_model_data = self._compute_logprobs(model, model_data, context_length) + + # Compute loss + loss, score, kl_div = self._compute_losses(model_logprobs_model_data, ref_logprobs_model_data, probability) + + # Log everything + self._log_statistics( + model_data, + mixture_data, + model_logprobs_model_data.detach(), + ref_logprobs_model_data, + probability, + score.detach(), + kl_div.detach(), + context_length, + model_scores, + mixture_scores, + ) + + if ( + self.args.torch_empty_cache_steps is not None + and self.state.global_step % self.args.torch_empty_cache_steps == 0 + ): + empty_cache() + + kwargs = {} + # For LOMO optimizers you need to explicitly use the learning rate + if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: + kwargs["learning_rate"] = self._get_learning_rate() + + if self.args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel training + + if self.use_apex: + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + else: + self.accelerator.backward(loss, **kwargs) + + return loss.detach() / self.args.gradient_accumulation_steps + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, list[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str` or `None`, *optional*, defaults to `None`): + Name of the model. + dataset_name (`str` or `None`, *optional*, defaults to `None`): + Name of the dataset used for training. + tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. + """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + tags = tags or set() + if isinstance(tags, str): + tags = {tags} + + if hasattr(self.model.config, "unsloth_version"): + tags.add("unsloth") + + tags.update(self._tag_names) + + citation = textwrap.dedent("""\ + @inproceedings{munos2024nash, + title = {{Nash Learning from Human Feedback}}, + author = {R{\'{e}}mi Munos and Michal Valko and Daniele Calandriello and Mohammad Gheshlaghi Azar and Mark Rowland and Zhaohan Daniel Guo and Yunhao Tang and Matthieu Geist and Thomas Mesnard and C{\\^{o}}me Fiegel and Andrea Michi and Marco Selvi and Sertan Girgin and Nikola Momchev and Olivier Bachem and Daniel J. Mankowitz and Doina Precup and Bilal Piot}, + year = 2024, + booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024}, + publisher = {OpenReview.net}, + url = {https://openreview.net/forum?id=Y5AmNYiyCQ} + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), + trainer_name="Nash-MD", + trainer_citation=citation, + paper_title="Nash Learning from Human Feedback", + paper_id="2312.00886", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py new file mode 100644 index 0000000000000000000000000000000000000000..ce8c8ffb94f2ca524379fd2aca1e0beca0647d24 --- /dev/null +++ b/trl/trainer/online_dpo_config.py @@ -0,0 +1,185 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional + +from transformers import TrainingArguments + + +@dataclass +class OnlineDPOConfig(TrainingArguments): + r""" + Configuration class for the [`OnlineDPOTrainer`]. + + This class includes only the parameters that are specific to Online DPO training. For a full list of training + arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this + class may differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + reward_model_path (`str` or `None`, *optional*, defaults to `None`): + Path to the reward model. Either `judge` or `reward_model_path` must be set, but not both. + judge (`str` or `None`, *optional*, defaults to `None`): + Name of the judge to use. Either `judge` or `reward_model_path` must be set, but not both. + max_new_tokens (`int`, *optional*, defaults to `64`): + Maximum number of tokens to generate per completion. + max_length (`int`, *optional*, defaults to `256`): + Maximum total length of the sequence (prompt + completion) used to compute log probabilities. If the + sequence exceeds this limit, the leftmost tokens will be truncated to preserve as much of the completion as + possible. + temperature (`float`, *optional*, defaults to `0.9`): + Temperature for sampling. The higher the temperature, the more random the completions. + missing_eos_penalty (`float` or `None`, *optional*, defaults to `None`): + Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage + to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive + value. + beta (`float` or `list[float]`, *optional*, defaults to `0.1`): + Parameter controlling the deviation from the reference model. Higher β means less deviation from the + reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in + the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β is + selected for each new epoch and the last β is used for the rest of the epochs. + loss_type (`str`, *optional*, defaults to `"sigmoid"`): + Type of loss to use. Possible values are: + + - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper. + - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper. + + dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): + Number of processes to use for processing the dataset. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model and reference model. + use_vllm (`bool`, *optional*, defaults to `False`): + Whether to use vLLM for generating completions. Requires vLLM to be installed (`pip install vllm`). + gpu_memory_utilization (`float`, *optional*, defaults to `0.55`): + The vLLM memory utilization. The default value is 0.55. + ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): + This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, + improving generation speed. However, disabling this option allows training models that exceed the VRAM + capacity of a single GPU, albeit at the cost of slower generation. + """ + + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=5e-7, + metadata={"help": "The initial learning rate for AdamW."}, + ) + logging_steps: float = field( + default=10, + metadata={ + "help": ( + "Log every X updates steps. Should be an integer or a float in range `[0,1)`. " + "If smaller than 1, will be interpreted as ratio of total training steps." + ) + }, + ) + bf16: bool = field( + default=True, + metadata={ + "help": ( + "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change." + ) + }, + ) + + reward_model_path: Optional[str] = field( + default=None, + metadata={ + "help": "Path to the reward model. Either `judge` or `reward_model_path` must be set, but not both." + }, + ) + judge: Optional[str] = field( + default=None, + metadata={ + "help": "Name of the judge to use. Either `judge` or `reward_model_path` must be set, but not both." + }, + ) + max_new_tokens: int = field( + default=64, + metadata={"help": "Maximum number of tokens to generate per completion."}, + ) + max_length: int = field( + default=512, + metadata={ + "help": "Maximum total length of the sequence (prompt + completion) used to compute log probabilities. If " + "the sequence exceeds this limit, the leftmost tokens will be truncated to preserve as much of the " + "completion as possible." + }, + ) + temperature: float = field( + default=0.9, + metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."}, + ) + missing_eos_penalty: Optional[float] = field( + default=None, + metadata={ + "help": "Penalty applied to the score when the model fails to generate an EOS token. This is useful to " + "encourage to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be " + "a positive value." + }, + ) + beta: list[float] = field( + default_factory=lambda: [0.1], + metadata={ + "help": "Parameter controlling the deviation from the reference model. Higher β means less deviation from " + "the reference model. For the IPO loss (`loss_type='ipo'`), β is the regularization parameter denoted by " + "τ in the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β " + "is selected for each new epoch and the last β is used for the rest of the epochs." + }, + ) + loss_type: str = field( + default="sigmoid", + metadata={ + "help": "Type of loss to use.", + "choices": ["sigmoid", "ipo"], + }, + ) + dataset_num_proc: Optional[int] = field( + default=None, + metadata={"help": "Number of processes to use for processing the dataset."}, + ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropout in the model."}, + ) + use_vllm: bool = field( + default=False, + metadata={ + "help": "Whether to use vLLM for generating completions. Requires vLLM to be installed " + "(`pip install vllm`)." + }, + ) + gpu_memory_utilization: Optional[float] = field( + default=0.55, + metadata={ + "help": "The vLLM memory utilization. The default value is 0.55.", + }, + ) + ds3_gather_for_generation: bool = field( + default=True, + metadata={ + "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for " + "generation, improving generation speed. However, disabling this option allows training models that " + "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation." + }, + ) + + def __post_init__(self): + super().__post_init__() + if hasattr(self.beta, "__len__") and len(self.beta) == 1: + self.beta = self.beta[0] diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..aad73ad3f3ec891757c584628d7128b201a36c02 --- /dev/null +++ b/trl/trainer/online_dpo_trainer.py @@ -0,0 +1,803 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import textwrap +import warnings +from functools import wraps +from pathlib import Path +from typing import Any, Callable, Optional, Union + +import datasets +import jinja2 +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.data +from datasets import Dataset +from packaging import version +from torch.utils.data import DataLoader, IterableDataset +from transformers import ( + BaseImageProcessor, + DataCollator, + FeatureExtractionMixin, + GenerationConfig, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + Trainer, + TrainerCallback, + is_apex_available, + is_wandb_available, +) +from transformers.trainer_utils import EvalPrediction, seed_worker +from transformers.training_args import OptimizerNames +from transformers.utils import is_peft_available, is_sagemaker_mp_enabled, logging + +from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template +from ..import_utils import is_vllm_available +from ..models import create_reference_model +from ..models.utils import unwrap_model_for_generation +from .judges import BasePairwiseJudge +from .online_dpo_config import OnlineDPOConfig +from .utils import ( + SIMPLE_CHAT_TEMPLATE, + DPODataCollatorWithPadding, + disable_dropout_in_model, + empty_cache, + generate_model_card, + get_comet_experiment_url, + get_reward, + prepare_deepspeed, + truncate_right, +) + + +if is_peft_available(): + from peft import PeftModel, get_peft_model + +if is_apex_available(): + from apex import amp + + +if is_sagemaker_mp_enabled(): + from smdistributed.modelparallel import __version__ as SMP_VERSION + + IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10") + +else: + IS_SAGEMAKER_MP_POST_1_10 = False + + +if is_vllm_available(): + from vllm import LLM, SamplingParams + +if is_wandb_available(): + import wandb + +logger = logging.get_logger(__name__) + + +class OnlineDPOTrainer(Trainer): + r""" + Initialize OnlineDPOTrainer. + + Args: + model (`transformers.PreTrainedModel` or `torch.nn.Module`): + The model to train, preferably an `AutoModelForCausalLM`. + ref_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`): + The reference model to use for training. If None is specified, the reference model will be created from + the model. + reward_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`): + The reward model to score completions with, preferably an `AutoModelForSequenceClassification`. + judge (`BasePairwiseJudge`): + The judge to use for pairwise comparison of model completions. + args (`OnlineDPOConfig`): + The online DPO config arguments to use for training. + data_collator (`transformers.DataCollator`): + The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used + which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences. + train_dataset (`datasets.Dataset`): + The dataset to use for training. + eval_dataset (`datasets.Dataset`): + The dataset to use for evaluation. + processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*): + Processing class used to process the data. If provided, will be used to automatically process the inputs + for the model, and it will be saved along the model to make it easier to rerun an interrupted training or + reuse the fine-tuned model. + peft_config (`dict`): + The peft config to use for training. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function to use to compute the metrics. Must take a `EvalPrediction` and return + a dictionary string to metric values. + callbacks (`list[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + """ + + _tag_names = ["trl", "online-dpo"] + + def __init__( + self, + model: Union[PreTrainedModel, nn.Module], + ref_model: Union[PreTrainedModel, nn.Module, None] = None, + reward_model: Union[PreTrainedModel, nn.Module, None] = None, + judge: Optional[BasePairwiseJudge] = None, + args: Optional[OnlineDPOConfig] = None, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None, + eval_dataset: Optional[Union[Dataset, dict[str, Dataset], "datasets.Dataset"]] = None, + processing_class: Optional[ + Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] + ] = None, + reward_processing_class: Optional[PreTrainedTokenizerBase] = None, + peft_config: Optional[dict] = None, + compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, + callbacks: Optional[list[TrainerCallback]] = None, + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + ) -> None: + if ref_model is model: + raise ValueError( + "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " + "same as `model`, either omit the `ref_model` argument or pass `None`." + ) + + self.ref_model = ref_model + + if reward_model is not None and judge is not None: + warnings.warn( + "Both `reward_model` and `judge` are provided. Please choose provide only one of them. " + "Ignoring `judge` and using `reward_model`.", + UserWarning, + ) + judge = None + elif reward_model is None and judge is None: + raise ValueError("Either `reward_model` or `judge` must be provided.") + + self.reward_model = reward_model + self.reward_processing_class = reward_processing_class + self.judge = judge + self.is_encoder_decoder = model.config.is_encoder_decoder + + if args.missing_eos_penalty is not None and judge is not None: + raise ValueError("`missing_eos_penalty` is not supported when `judge` is provided.") + + if args is None: + raise ValueError("`args` must be provided.") + + # Check that the processing_class is provided + if processing_class is None: + raise ValueError("`processing_class` must be provided.") + + # Convert to PEFT model if peft_config is provided + if peft_config is not None: + # Check if PEFT is available + if not is_peft_available(): + raise ImportError( + "PEFT is not available and passed `peft_config`. Please install PEFT with " + "`pip install peft` to use it." + ) + + # If the model is already a PeftModel, we need to merge and unload it. + # Further information here: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft + if isinstance(model, PeftModel): + model = model.merge_and_unload() + + # Get peft model with the given config + model = get_peft_model(model, peft_config) + + # Disable dropout in the model and reference model + if args.disable_dropout: + disable_dropout_in_model(model) + if self.ref_model is not None: + disable_dropout_in_model(self.ref_model) + + # Handle the ref_model + # Usually, the user wants the ref model to be the initial version of the model. When using PEFT, it's easy to + # get the ref model, as it's just the model with a disabled adapter. When not using PEFT, we need to create + # the ref model from the model by copying it and disable the gradients and set it in evaluation mode. + if ref_model is None: # No ref model provided, the most common case + if peft_config is None: + self.ref_model = create_reference_model(model) # copy, disable gradients, set eval mode + else: + self.ref_model = None # we don't need a ref model here, we can just disable the adapter. + else: # rare case, the user provided a ref model + self.ref_model = ref_model + self.ref_model.eval() + + # Disable the gradient and set the reward model in eval mode + if self.reward_model is not None: + self.reward_model.eval() + + # Define the collator is not provided + if data_collator is None: + data_collator = DPODataCollatorWithPadding(pad_token_id=processing_class.pad_token_id) + + self.max_length = args.max_length + + self.stats = { + "objective/kl": [], + "objective/entropy": [], + "objective/non_score_reward": [], + "rewards/chosen": [], + "rewards/rejected": [], + "rewards/accuracies": [], + "rewards/margins": [], + "logps/chosen": [], + "logps/rejected": [], + "val/contain_eos_token": [], + "beta": [], + } + if self.reward_model is not None: + self.stats["objective/rlhf_reward"] = [] + self.stats["objective/scores_margin"] = [] + self.stats["objective/scores"] = [] + + if args.use_vllm: + if not is_vllm_available(): + raise ImportError( + "vLLM is not available and `use_vllm` is set to True. Please install vLLM with " + "`pip install vllm` to use it." + ) + self.generation_config = SamplingParams( + n=2, # 2 generations per prompt + max_tokens=args.max_new_tokens, + temperature=args.temperature, + top_k=50, + top_p=1.0, + detokenize=False, # to avoid vllm to decode (we don't need it) + ) + # vLLM dynamically adjusts the size of the key-value cache based on available GPU memory at instantiation. + # A larger cache size improves speed, so we would expect gpu_memory_utilization=1. + # However, at this stage, the optimizer's weights are not yet loaded onto the GPU; they will be loaded + # after the first optimizer step and remain in GPU memory throughout training. So we must reserve enough + # space for them. Setting gpu_memory_utilization to 0.55 seems to work well in practice. + self.llm = LLM( + model=model.name_or_path, + gpu_memory_utilization=args.gpu_memory_utilization, + dtype=torch.float32, + # When release by vLLM, we would be able to distribute the model on multiple GPUs + # See https://github.com/vllm-project/vllm/pull/12071 + # tensor_parallel_size=torch.cuda.device_count(), + # distributed_executor_backend="external_launcher", + ) + else: + self.generation_config = GenerationConfig( + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + top_k=50, + top_p=1.0, + do_sample=True, + use_cache=False if args.gradient_checkpointing else True, + ) + + # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the + # input tensor associated with the key "input_ids". However, in Online DPO, the sampled data does not include + # the "input_ids" key. As a result, the trainer issues the warning: "Could not estimate the number of tokens + # of the input, floating-point operations will not be computed." To suppress this warning, we set the + # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate + # that the warning has already been issued. + model.warnings_issued["estimate_tokens"] = True + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + self._beta = args.beta + + # Placed after the super().__init__ because we need self.is_deepspeed_enabled and self.accelerator + if self.is_deepspeed_enabled: + if self.reward_model is not None: + self.reward_model = prepare_deepspeed( + self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 + ) + if self.ref_model is not None: + self.ref_model = prepare_deepspeed( + self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16 + ) + else: + if self.ref_model is not None: + self.ref_model = self.ref_model.to(self.accelerator.device) + if self.reward_model is not None: + self.reward_model = self.reward_model.to(self.accelerator.device) + + @property + def beta(self): + if isinstance(self._beta, list): + epoch = self.state.epoch + return self._beta[epoch] if epoch < len(self._beta) else self._beta[-1] + else: + return self._beta + + @staticmethod + def tokenize_row(feature, is_encoder_decoder: bool, tokenizer: PreTrainedTokenizerBase) -> dict[str, Any]: + """Tokenize a single row from a DPO specific dataset.""" + if not is_encoder_decoder: + batch = tokenizer(feature["prompt"], add_special_tokens=False) + # Add BOS token to head of prompt. Avoid adding if it's already there + if tokenizer.bos_token_id is not None: + prompt_len_input_ids = len(batch["input_ids"]) + if prompt_len_input_ids == 0 or tokenizer.bos_token_id != batch["input_ids"][0]: + batch["input_ids"] = [tokenizer.bos_token_id] + batch["input_ids"] + batch["attention_mask"] = [1] + batch["attention_mask"] + else: + batch = tokenizer(feature["prompt"], add_special_tokens=True) + batch = {f"prompt_{key}": value for key, value in batch.items()} + return batch + + # Same as Trainer.get_train_dataloader but skip the "remove_unused_columns". + @wraps(Trainer.get_train_dataloader) + def get_train_dataloader(self) -> DataLoader: + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + dataloader_params = { + "batch_size": self._train_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_train_sampler() + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = seed_worker + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) + + # Same as Trainer.get_eval_dataloader but skip the "remove_unused_columns". + @wraps(Trainer.get_eval_dataloader) + def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader: + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + + # If we have persistent workers, don't do a fork bomb especially as eval datasets + # don't change during training + dataloader_key = eval_dataset if isinstance(eval_dataset, str) else "eval" + if ( + hasattr(self, "_eval_dataloaders") + and dataloader_key in self._eval_dataloaders + and self.args.dataloader_persistent_workers + ): + return self.accelerator.prepare(self._eval_dataloaders[dataloader_key]) + + eval_dataset = ( + self.eval_dataset[eval_dataset] + if isinstance(eval_dataset, str) + else eval_dataset + if eval_dataset is not None + else self.eval_dataset + ) + data_collator = self.data_collator + + dataloader_params = { + "batch_size": self.args.eval_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + "persistent_workers": self.args.dataloader_persistent_workers, + } + + if not isinstance(eval_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset) + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + # accelerator.free_memory() will destroy the references, so + # we need to store the non-prepared version + eval_dataloader = DataLoader(eval_dataset, **dataloader_params) + if self.args.dataloader_persistent_workers: + if hasattr(self, "_eval_dataloaders"): + self._eval_dataloaders[dataloader_key] = eval_dataloader + else: + self._eval_dataloaders = {dataloader_key: eval_dataloader} + + return self.accelerator.prepare(eval_dataloader) + + def _generate_vllm(self, model, prompts): + eos_token_id = self.processing_class.eos_token_id + pad_token_id = self.processing_class.pad_token_id + + # Load the latest weights + llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model + llm_model.load_weights(model.state_dict().items()) + + if is_conversational({"prompt": prompts[0]}): + outputs = self.llm.chat(prompts, self.generation_config, use_tqdm=False) + else: + outputs = self.llm.generate(prompts, self.generation_config, use_tqdm=False) + + completion_ids = [list(output.outputs[i].token_ids) for i in range(2) for output in outputs] + prompt_ids = [list(output.prompt_token_ids) for _ in range(2) for output in outputs] + + # Create mask and pad the prompt and completion + max_prompt_length = max(len(ids) for ids in prompt_ids) + prompt_mask = [[0] * (max_prompt_length - len(ids)) + [1] * len(ids) for ids in prompt_ids] + prompt_ids = [[pad_token_id] * (max_prompt_length - len(ids)) + ids for ids in prompt_ids] + max_tokens = self.generation_config.max_tokens + completion_mask = [[1] * len(ids) + [0] * (max_tokens - len(ids)) for ids in completion_ids] + completion_ids = [ + ids + [eos_token_id] if ids[-1] != eos_token_id and len(ids) < max_tokens else ids + for ids in completion_ids + ] + completion_ids = [ids + [pad_token_id] * (max_tokens - len(ids)) for ids in completion_ids] + + # Convert to tensors + prompt_ids = torch.tensor(prompt_ids, device=self.accelerator.device) + prompt_mask = torch.tensor(prompt_mask, device=self.accelerator.device) + completion_ids = torch.tensor(completion_ids, device=self.accelerator.device) + completion_mask = torch.tensor(completion_mask, device=self.accelerator.device) + + return prompt_ids, prompt_mask, completion_ids, completion_mask + + def _generate(self, model, prompts): + eos_token_id = self.processing_class.eos_token_id + pad_token_id = self.processing_class.pad_token_id + + # Apply chat template and tokenize the input. We do this on-the-fly to enable the use of reward models and + # policies with different tokenizers / chat templates. + inputs = [{"prompt": prompt} for prompt in prompts] + inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs] + inputs = [self.tokenize_row(x, self.is_encoder_decoder, self.processing_class) for x in inputs] + inputs = self.data_collator(inputs) + + # Sample 2 completions per prompt of size `max_new_tokens` from the model + inputs = self._prepare_inputs(inputs) + prompt_ids = inputs["prompt_input_ids"].repeat(2, 1) + prompt_mask = inputs["prompt_attention_mask"].repeat(2, 1) + with unwrap_model_for_generation( + model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model: + output = unwrapped_model.generate( + input_ids=prompt_ids, + attention_mask=prompt_mask, + generation_config=self.generation_config, + ) + + completion_ids = output[:, prompt_ids.size(1) :] + completion_ids, completion_mask = truncate_right(completion_ids, eos_token_id, pad_token_id) + + return prompt_ids, prompt_mask, completion_ids, completion_mask + + def _forward(self, model, prompt_ids, prompt_mask, completion_ids, completion_mask): + # Get the number of tokens to truncate from prompt + num_tokens_to_truncate = max(prompt_ids.size(1) + completion_ids.size(1) - self.max_length, 0) + + # Truncate left to avoid oom + prompt_ids = prompt_ids[:, num_tokens_to_truncate:] + prompt_mask = prompt_mask[:, num_tokens_to_truncate:] + + # Concat the prompt and completion + prompt_completion_ids = torch.cat((prompt_ids, completion_ids), dim=1) + prompt_completion_mask = torch.cat((prompt_mask, completion_mask), dim=1) + + # Get the logprobs of the completions from the model + output = model(prompt_completion_ids, attention_mask=prompt_completion_mask) + + # There is 1 offset, because the model predict the next token + logits = output.logits[:, prompt_ids.size(1) - 1 : -1] + + # Take the completion tokens logprob + logprobs = torch.take_along_dim(logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1) + return logprobs + + def training_step( + self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None + ) -> torch.Tensor: + model.train() + + prompts = inputs["prompt"] + batch_size = len(prompts) + + if self.args.use_vllm: + prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate_vllm(model, prompts) + else: + prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate(model, prompts) + + contain_eos_token = torch.any(completion_ids == self.processing_class.eos_token_id, dim=-1) + + logprobs = self._forward(model, prompt_ids, prompt_mask, completion_ids, completion_mask) + with torch.no_grad(): + if self.ref_model is not None: + ref_logprobs = self._forward(self.ref_model, prompt_ids, prompt_mask, completion_ids, completion_mask) + else: # peft case: we just need to disable the adapter + with self.model.disable_adapter(): + ref_logprobs = self._forward(self.model, prompt_ids, prompt_mask, completion_ids, completion_mask) + + # Decode the completions, and format them if the input is conversational + device = logprobs.device + completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + if is_conversational({"prompt": prompts[0]}): + completions = [[{"role": "assistant", "content": completion}] for completion in completions] + + # Get the reward from the reward model or judge + if self.judge is not None: + # Once formatted, conversational data may contain special tokens (such as <|im_start|>) that are not + # directly understandable by the judge and could alter its judgment. To avoid this and make the judge + # independent of the model's chat template, we use the raw conversation data, and apply our own chat + # template to it. + if is_conversational({"prompt": prompts[0]}): + environment = jinja2.Environment() + template = environment.from_string(SIMPLE_CHAT_TEMPLATE) + prompts = [template.render(messages=prompt) for prompt in prompts] + completions = [template.render(messages=completion) for completion in completions] + + ranks_of_first_completion = self.judge.judge( + prompts, list(zip(completions[:batch_size], completions[batch_size:])) + ) + + # convert ranks to a True/False mask: + # when rank == 0, it means the first completion is the best + # when rank == 1, it means the second completion is the best + mask = torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=device) + else: + # The reward model may not have the same chat template or tokenizer as the model, so we need to use the + # raw data (string), apply the chat template (if needed), and tokenize it with the reward processing class. + prompts = 2 * prompts # repeat the prompt: [prompt0, prompt1] -> [prompt0, prompt1, prompt0, prompt1] + if is_conversational({"prompt": prompts[0]}): + examples = [{"prompt": p, "completion": c} for p, c in zip(prompts, completions)] + examples = [apply_chat_template(example, self.reward_processing_class) for example in examples] + prompts = [example["prompt"] for example in examples] + completions = [example["completion"] for example in examples] + + # Tokenize the prompts + prompts_ids = self.reward_processing_class( + prompts, padding=True, return_tensors="pt", padding_side="left" + )["input_ids"].to(device) + context_length = prompts_ids.shape[1] + + # Tokenize the completions + completions_ids = self.reward_processing_class( + completions, padding=True, return_tensors="pt", padding_side="right" + )["input_ids"].to(device) + + # Concatenate the prompts and completions and get the reward + prompt_completion_ids = torch.cat((prompts_ids, completions_ids), dim=1) + with torch.inference_mode(): + _, scores, _ = get_reward( + self.reward_model, prompt_completion_ids, self.reward_processing_class.pad_token_id, context_length + ) + + # Filter completion. Ensure that the sample contains stop_token_id + # Completions not passing that filter will receive a lower score. + if self.args.missing_eos_penalty is not None: + scores[~contain_eos_token] -= self.args.missing_eos_penalty + + # Split the scores in 2 (the prompts of the first half are the same as the second half) + first_half, second_half = scores.split(batch_size) + + # Get the indices of the chosen and rejected examples + mask = first_half >= second_half + + batch_range = torch.arange(batch_size, device=device) + chosen_indices = batch_range + (~mask * batch_size) + rejected_indices = batch_range + (mask * batch_size) + + # Build tensor so that the first half is the chosen examples and the second half the rejected examples + cr_indices = torch.cat((chosen_indices, rejected_indices), dim=0) # cr = chosen and rejected + cr_logprobs = logprobs[cr_indices] + cr_ref_logprobs = ref_logprobs[cr_indices] + + # mask out the padding tokens + padding_mask = ~completion_mask.bool() + cr_padding_mask = padding_mask[cr_indices] + + cr_logprobs_sum = (cr_logprobs * ~cr_padding_mask).sum(1) + cr_ref_logprobs_sum = (cr_ref_logprobs * ~cr_padding_mask).sum(1) + + # Split the chosen and rejected examples + chosen_logprobs_sum, rejected_logprobs_sum = torch.split(cr_logprobs_sum, batch_size) + chosen_ref_logprobs_sum, rejected_ref_logprobs_sum = torch.split(cr_ref_logprobs_sum, batch_size) + pi_logratios = chosen_logprobs_sum - rejected_logprobs_sum + ref_logratios = chosen_ref_logprobs_sum - rejected_ref_logprobs_sum + + logits = pi_logratios - ref_logratios + + if self.args.loss_type == "sigmoid": + losses = -F.logsigmoid(self.beta * logits) + elif self.args.loss_type == "ipo": + losses = (logits - 1 / (2 * self.beta)) ** 2 + else: + raise NotImplementedError(f"invalid loss type {self.loss_type}") + + loss = losses.mean() + + # Log everything + if self.reward_model is not None: + scores_margin = scores[chosen_indices] - scores[rejected_indices] + self.stats["objective/scores_margin"].append( + self.accelerator.gather_for_metrics(scores_margin.mean()).mean().item() + ) + self.stats["objective/scores"].append(self.accelerator.gather_for_metrics(scores.mean()).mean().item()) + self.stats["val/contain_eos_token"].append(contain_eos_token.float().mean().item()) + self.stats["logps/chosen"].append(self.accelerator.gather_for_metrics(chosen_logprobs_sum).mean().item()) + self.stats["logps/rejected"].append(self.accelerator.gather_for_metrics(rejected_logprobs_sum).mean().item()) + + kl = logprobs - ref_logprobs + mean_kl = kl.sum(1).mean() + self.stats["objective/kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item()) + non_score_reward = (-self.beta * kl).sum(1) + mean_non_score_reward = non_score_reward.mean() + self.stats["objective/non_score_reward"].append( + self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item() + ) + if self.reward_model is not None: + rlhf_reward = scores + non_score_reward + self.stats["objective/rlhf_reward"].append(self.accelerator.gather_for_metrics(rlhf_reward).mean().item()) + mean_entropy = -logprobs.sum(1).mean() + self.stats["objective/entropy"].append(self.accelerator.gather_for_metrics(mean_entropy).mean().item()) + chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum) + gathered_chosen_rewards = self.accelerator.gather_for_metrics(chosen_rewards) + self.stats["rewards/chosen"].append(gathered_chosen_rewards.mean().item()) + rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum) + gathered_rejected_rewards = self.accelerator.gather_for_metrics(rejected_rewards) + self.stats["rewards/rejected"].append(gathered_rejected_rewards.mean().item()) + margin = gathered_chosen_rewards - gathered_rejected_rewards + self.stats["rewards/margins"].append(margin.mean().item()) + accuracy = margin > 0 + self.stats["rewards/accuracies"].append(accuracy.float().mean().item()) + self.stats["beta"].append(self.beta) + + if ( + self.args.torch_empty_cache_steps is not None + and self.state.global_step % self.args.torch_empty_cache_steps == 0 + ): + empty_cache() + + kwargs = {} + + # For LOMO optimizers you need to explicitly use the learnign rate + if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: + kwargs["learning_rate"] = self._get_learning_rate() + + if self.args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel training + + if self.use_apex: + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + else: + self.accelerator.backward(loss, **kwargs) + + return loss.detach() / self.args.gradient_accumulation_steps + + # Same as Trainer._maybe_log_save_evaluate but log our metrics + def _maybe_log_save_evaluate( + self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time, learning_rate=None + ): + if self.control.should_log and self.state.global_step > self._globalstep_last_logged: + logs: dict[str, float] = {} + + # all_gather + mean() to get average loss over all processes + tr_loss_scalar = self._nested_gather(tr_loss).mean().item() + + # reset tr_loss to zero + tr_loss -= tr_loss + + logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) + if grad_norm is not None: + logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm + if learning_rate is not None: + logs["learning_rate"] = learning_rate + else: + logs["learning_rate"] = self._get_learning_rate() + + # Add our metrics + for key, val in self.stats.items(): + logs[key] = sum(val) / len(val) + self.stats = {key: [] for key in self.stats} # reset stats + + self._total_loss_scalar += tr_loss_scalar + self._globalstep_last_logged = self.state.global_step + self.store_flos() + self.log(logs, start_time) + + metrics = None + if self.control.should_evaluate: + metrics = self._evaluate(trial, ignore_keys_for_eval) + is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial) + + if self.args.save_strategy == "best": + self.control.should_save = is_new_best_metric + + if self.control.should_save: + self._save_checkpoint(model, trial) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, list[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str` or `None`, *optional*, defaults to `None`): + Name of the model. + dataset_name (`str` or `None`, *optional*, defaults to `None`): + Name of the dataset used for training. + tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. + """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + tags = tags or set() + if isinstance(tags, str): + tags = {tags} + + if hasattr(self.model.config, "unsloth_version"): + tags.add("unsloth") + + tags.update(self._tag_names) + + citation = textwrap.dedent("""\ + @article{guo2024direct, + title = {{Direct Language Model Alignment from Online AI Feedback}}, + author = {Shangmin Guo and Biao Zhang and Tianlin Liu and Tianqi Liu and Misha Khalman and Felipe Llinares and Alexandre Ram{\'{e}} and Thomas Mesnard and Yao Zhao and Bilal Piot and Johan Ferret and Mathieu Blondel}, + year = 2024, + eprint = {arXiv:2402.04792} + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), + trainer_name="Online DPO", + trainer_citation=citation, + paper_title="Direct Language Model Alignment from Online AI Feedback", + paper_id="2402.04792", + ) + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/orpo_config.py b/trl/trainer/orpo_config.py new file mode 100644 index 0000000000000000000000000000000000000000..19e733126d9cab79b2591c8cf873eb8e1dbd2899 --- /dev/null +++ b/trl/trainer/orpo_config.py @@ -0,0 +1,160 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any, Optional + +from transformers import TrainingArguments + + +@dataclass +class ORPOConfig(TrainingArguments): + r""" + Configuration class for the [`ORPOTrainer`]. + + This class includes only the parameters that are specific to ORPO training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may + differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want + to use the default data collator. + max_prompt_length (`int` or `None`, *optional*, defaults to `512`): + Maximum length of the prompt. This argument is required if you want to use the default data collator. + max_completion_length (`int` or `None`, *optional*, defaults to `None`): + Maximum length of the completion. This argument is required if you want to use the default data collator + and your model is an encoder-decoder. + beta (`float`, *optional*, defaults to `0.1`): + Parameter controlling the relative ratio loss weight in the ORPO loss. In the [paper](https://huggingface.co/papers/2403.07691), + it is denoted by λ. In the [code](https://github.com/xfactlab/orpo), it is denoted by `alpha`. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model. + label_pad_token_id (`int`, *optional*, defaults to `-100`): + Label pad token id. This argument is required if you want to use the default data collator. + padding_value (`int` or `None`, *optional*, defaults to `None`): + Padding value to use. If `None`, the padding value of the tokenizer is used. + truncation_mode (`str`, *optional*, defaults to `"keep_end"`): + Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`. + This argument is required if you want to use the default data collator. + generate_during_eval (`bool`, *optional*, defaults to `False`): + If `True`, generates and logs completions from the model to W&B or Comet during evaluation. + is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`): + When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument, + you need to specify if the model returned by the callable is an encoder-decoder model. + model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): + Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a + string. + dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): + Number of processes to use for processing the dataset. + """ + + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=1e-6, + metadata={"help": "The initial learning rate for AdamW."}, + ) + logging_steps: float = field( + default=10, + metadata={ + "help": ( + "Log every X updates steps. Should be an integer or a float in range `[0,1)`. " + "If smaller than 1, will be interpreted as ratio of total training steps." + ) + }, + ) + bf16: bool = field( + default=True, + metadata={ + "help": ( + "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change." + ) + }, + ) + + max_length: Optional[int] = field( + default=1024, + metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."}, + ) + max_prompt_length: Optional[int] = field( + default=512, + metadata={ + "help": "Maximum length of the prompt. This argument is required if you want to use the default data " + "collator and your model is an encoder-decoder." + }, + ) + max_completion_length: Optional[int] = field( + default=None, + metadata={ + "help": "Maximum length of the completion. This argument is required if you want to use the default data " + "collator and your model is an encoder-decoder." + }, + ) + beta: float = field( + default=0.1, + metadata={ + "help": "Parameter controlling the relative ratio loss weight in the ORPO loss. In the paper, it is " + "denoted by λ." + }, + ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropout in the model."}, + ) + label_pad_token_id: int = field( + default=-100, + metadata={ + "help": "Label pad token id. This argument is required if you want to use the default data collator." + }, + ) + padding_value: Optional[int] = field( + default=None, + metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."}, + ) + truncation_mode: str = field( + default="keep_end", + metadata={ + "help": "Truncation mode to use when the prompt is too long.", + "choices": ["keep_end", "keep_start"], + }, + ) + generate_during_eval: bool = field( + default=False, + metadata={"help": "If `True`, generates and logs completions from the model to W&B during evaluation."}, + ) + is_encoder_decoder: Optional[bool] = field( + default=None, + metadata={ + "help": "When using the `model_init` argument (callable) to instantiate the model instead of the `model` " + "argument, you need to specify if the model returned by the callable is an encoder-decoder model." + }, + ) + model_init_kwargs: Optional[dict[str, Any]] = field( + default=None, + metadata={ + "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model " + "from a string." + }, + ) + dataset_num_proc: Optional[int] = field( + default=None, + metadata={"help": "Number of processes to use for processing the dataset."}, + ) diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..56a19bcacb0b171d2274a73345651b166bbe376e --- /dev/null +++ b/trl/trainer/orpo_trainer.py @@ -0,0 +1,1076 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import os +import random +import textwrap +import warnings +from collections import defaultdict +from contextlib import nullcontext +from pathlib import Path +from typing import Any, Callable, Literal, Optional, Union + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +from accelerate import PartialState +from datasets import Dataset +from torch import autocast +from torch.utils.data import DataLoader +from transformers import ( + AutoModelForCausalLM, + BaseImageProcessor, + DataCollator, + FeatureExtractionMixin, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + Trainer, + is_comet_available, + is_torch_xla_available, + is_wandb_available, +) +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalLoopOutput +from transformers.utils import is_peft_available, is_torch_fx_proxy + +from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt +from .orpo_config import ORPOConfig +from .utils import ( + DPODataCollatorWithPadding, + add_bos_token_if_needed, + add_eos_token_if_needed, + disable_dropout_in_model, + generate_model_card, + get_comet_experiment_url, + log_table_to_comet_experiment, + pad_to_length, + peft_module_casting_to_bf16, + selective_log_softmax, +) + + +if is_peft_available(): + from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + + +if is_wandb_available(): + import wandb + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + +class ORPOTrainer(Trainer): + r""" + Initialize ORPOTrainer. + + Args: + model (`transformers.PreTrainedModel`): + The model to train, preferably an `AutoModelForSequenceClassification`. + args (`ORPOConfig`): + The ORPO config arguments to use for training. + data_collator (`transformers.DataCollator`): + The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used + which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences. + train_dataset (`datasets.Dataset`): + The dataset to use for training. + eval_dataset (`datasets.Dataset`): + The dataset to use for evaluation. + processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*): + Processing class used to process the data. If provided, will be used to automatically process the inputs + for the model, and it will be saved along the model to make it easier to rerun an interrupted training or + reuse the fine-tuned model. + model_init (`Callable[[], transformers.PreTrainedModel]`): + The model initializer to use for training. If None is specified, the default model initializer will be used. + callbacks (`list[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + peft_config (`dict`, defaults to `None`): + The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function to use to compute the metrics. Must take a `EvalPrediction` and return + a dictionary string to metric values. + """ + + _tag_names = ["trl", "orpo"] + + def __init__( + self, + model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + args: Optional[ORPOConfig] = None, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, + processing_class: Optional[ + Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] + ] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + callbacks: Optional[list[TrainerCallback]] = None, + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + peft_config: Optional[dict] = None, + compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None, + ): + if args.model_init_kwargs is None: + model_init_kwargs = {} + elif not isinstance(model, str): + raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.") + else: + model_init_kwargs = args.model_init_kwargs + torch_dtype = model_init_kwargs.get("torch_dtype") + if torch_dtype is not None: + # Convert to `torch.dtype` if an str is passed + if isinstance(torch_dtype, str) and torch_dtype != "auto": + torch_dtype = getattr(torch, torch_dtype) + if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"Invalid `torch_dtype` passed to the ORPOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}." + ) + model_init_kwargs["torch_dtype"] = torch_dtype + + if isinstance(model, str): + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + + # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` + # has been called in order to properly call autocast if needed. + self._peft_has_been_casted_to_bf16 = False + + if not is_peft_available() and peft_config is not None: + raise ValueError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + # if model is a peft model and we have a peft_config, we merge and unload it first + if isinstance(model, PeftModel): + model = model.merge_and_unload() + + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): + _support_gc_kwargs = hasattr( + args, "gradient_checkpointing_kwargs" + ) and "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + + prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} + + if _support_gc_kwargs: + prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) + elif args.gradient_checkpointing: + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + # get peft model with the given config + model = get_peft_model(model, peft_config) + if args.bf16 and getattr(model, "is_loaded_in_4bit", False): + peft_module_casting_to_bf16(model) + # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager + self._peft_has_been_casted_to_bf16 = True + + # For models that use gradient_checkpointing, we need to attach a hook that enables input + # to explicitly have `requires_grad=True`, otherwise training will either silently + # fail or completely fail. + elif args.gradient_checkpointing: + # For backward compatibility with older versions of transformers + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + if args.generate_during_eval and not (is_wandb_available() or is_comet_available()): + raise ValueError( + "`generate_during_eval=True` requires Weights and Biases or Comet to be installed." + " Please install `wandb` or `comet-ml` to resolve." + ) + + if model is not None: + self.is_encoder_decoder = model.config.is_encoder_decoder + elif args.is_encoder_decoder is None: + raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.") + else: + self.is_encoder_decoder = args.is_encoder_decoder + + if self.is_encoder_decoder: + self.decoder_start_token_id = model.config.decoder_start_token_id + self.pad_token_id = model.config.pad_token_id + + if processing_class is None: + raise ValueError("processing_class must be specified to tokenize a ORPO dataset.") + if args.max_length is None: + warnings.warn( + "`max_length` is not set in the ORPOConfig's init" + " it will default to `512` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_length = 512 + else: + max_length = args.max_length + if args.max_prompt_length is None: + warnings.warn( + "`max_prompt_length` is not set in the ORPOConfig's init" + " it will default to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + max_prompt_length = 128 + else: + max_prompt_length = args.max_prompt_length + + if args.max_completion_length is None and self.is_encoder_decoder: + warnings.warn( + "When using an encoder decoder architecture, you should set `max_completion_length` in the ORPOConfig's init" + " it will default to `128` by default, but you should do it yourself in the future.", + UserWarning, + ) + self.max_completion_length = 128 + else: + self.max_completion_length = args.max_completion_length + + if data_collator is None: + data_collator = DPODataCollatorWithPadding( + pad_token_id=processing_class.pad_token_id, + label_pad_token_id=args.label_pad_token_id, + is_encoder_decoder=self.is_encoder_decoder, + ) + + if args.remove_unused_columns: + args.remove_unused_columns = False + # warn users + warnings.warn( + "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments" + " we have set it for you, but you should do it yourself in the future.", + UserWarning, + ) + + self.use_dpo_data_collator = True + else: + self.use_dpo_data_collator = False + + # Disable dropout in the model and reference model + if args.disable_dropout: + disable_dropout_in_model(model) + + self.max_length = max_length + self.generate_during_eval = args.generate_during_eval + self.label_pad_token_id = args.label_pad_token_id + self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id + self.max_prompt_length = max_prompt_length + self.truncation_mode = args.truncation_mode + self.processing_class = processing_class + + self.beta = args.beta + self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) + self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0) + if self.aux_loss_enabled and self.aux_loss_coef == 0.0: + warnings.warn( + "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to " + "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value " + "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary " + "loss.", + UserWarning, + ) + + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + + # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the + # input tensor associated with the key "input_ids". However, in ORPO, the sampled data does not include the + # "input_ids" key. Instead, the available keys are "prompt_input_ids", "chosen_input_ids", and + # "rejected_input_ids". As a result, the trainer issues the warning: "Could not estimate the number of tokens + # of the input, floating-point operations will not be computed." To suppress this warning, we set the + # "estimate_tokens" key in the model's "warnings_issued" dictionary to True. This acts as a flag to indicate + # that the warning has already been issued. + model.warnings_issued["estimate_tokens"] = True + + # Compute that only on the main process for faster data processing. + # see: https://github.com/huggingface/trl/pull/1255 + with PartialState().main_process_first(): + # Extract the prompt if needed, and apply the chat template if needed + train_dataset = train_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc) + train_dataset = train_dataset.map( + maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, num_proc=args.dataset_num_proc + ) + train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc) + if eval_dataset is not None: + eval_dataset = eval_dataset.map(maybe_extract_prompt, num_proc=args.dataset_num_proc) + eval_dataset = eval_dataset.map( + maybe_apply_chat_template, + fn_kwargs={"tokenizer": processing_class}, + num_proc=args.dataset_num_proc, + ) + eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + if not hasattr(self, "accelerator"): + raise AttributeError( + "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`." + ) + + def build_tokenized_answer(self, prompt, answer): + """ + Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`. + It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`. + Reference: + https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 + """ + + full_tokenized = self.processing_class(prompt + answer, add_special_tokens=False) + prompt_input_ids = self.processing_class(prompt, add_special_tokens=False)["input_ids"] + + answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :] + answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :] + + # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]` + full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids]) + + # Prepare input tokens for token by token comparison + full_input_ids = np.array(full_tokenized["input_ids"]) + + if len(full_input_ids) != len(full_concat_input_ids): + raise ValueError("Prompt input ids and answer input ids should have the same length.") + + # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens + # can be merged together when tokenizing prompt+answer. This could result + # on the last token from the prompt being different when tokenized on its own + # vs when done as prompt+answer. + response_token_ids_start_idx = len(prompt_input_ids) + + # If tokenized prompt is different than both prompt+answer, then it means the + # last token has changed due to merging. + if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]: + response_token_ids_start_idx -= 1 + + prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx] + prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx] + + if len(prompt_input_ids) != len(prompt_attention_mask): + raise ValueError("Prompt input ids and attention mask should have the same length.") + + answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:] + answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:] + + return dict( + prompt_input_ids=prompt_input_ids, + prompt_attention_mask=prompt_attention_mask, + input_ids=answer_input_ids, + attention_mask=answer_attention_mask, + ) + + def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict: + """Tokenize a single row from a ORPO specific dataset. + + At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation + in case the prompt + chosen or prompt + rejected responses is/are too long. First + we truncate the prompt; if we're still too long, we truncate the chosen/rejected. + + We also create the labels for the chosen/rejected responses, which are of length equal to + the sum of the length of the prompt and the chosen/rejected response, with + label_pad_token_id for the prompt tokens. + """ + batch = {} + prompt = feature["prompt"] + chosen = feature["chosen"] + rejected = feature["rejected"] + + if not self.is_encoder_decoder: + # Check issues below for more details + # 1. https://github.com/huggingface/trl/issues/907 + # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 + # 3. https://github.com/LianjiaTech/BELLE/issues/337 + + if not isinstance(prompt, str): + raise ValueError(f"prompt should be an str but got {type(prompt)}") + prompt_tokens = self.processing_class(prompt, add_special_tokens=False) + prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()} + + if not isinstance(chosen, str): + raise ValueError(f"chosen should be an str but got {type(chosen)}") + chosen_tokens = self.build_tokenized_answer(prompt, chosen) + + if not isinstance(rejected, str): + raise ValueError(f"rejected should be an str but got {type(rejected)}") + rejected_tokens = self.build_tokenized_answer(prompt, rejected) + + # Last prompt token might get merged by tokenizer and + # it should not be included for generation if that happens + prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"]) + + chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"]) + rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"]) + prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids) + + for k, v in prompt_tokens.items(): + prompt_tokens[k] = v[:prompt_len_input_ids] + + # Make sure prompts only have one different token at most an + # and length only differs by 1 at most + num_diff_tokens = sum( + [a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])] + ) + num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids) + if num_diff_tokens > 1 or num_diff_len > 1: + raise ValueError( + "Chosen and rejected prompt_input_ids might only differ on the " + "last token due to tokenizer merge ops." + ) + + # add BOS token to head of prompt. Avoid adding if it's already there + prompt_tokens, chosen_tokens, rejected_tokens = add_bos_token_if_needed( + self.processing_class.bos_token_id, + prompt_len_input_ids, + prompt_tokens, + chosen_prompt_len_input_ids, + chosen_tokens, + rejected_prompt_len_input_ids, + rejected_tokens, + ) + + # add EOS token to end of answer. Avoid adding if it's already there + chosen_tokens, rejected_tokens = add_eos_token_if_needed( + self.processing_class.eos_token_id, chosen_tokens, rejected_tokens + ) + + longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"])) + + # if combined sequence is too long, truncate the prompt + for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]: + if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: + if self.truncation_mode == "keep_start": + for k in ["prompt_input_ids", "prompt_attention_mask"]: + answer_tokens[k] = answer_tokens[k][: self.max_prompt_length] + elif self.truncation_mode == "keep_end": + for k in ["prompt_input_ids", "prompt_attention_mask"]: + answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :] + else: + raise ValueError(f"Unknown truncation mode: {self.truncation_mode}") + + # if that's still too long, truncate the response + for answer_tokens in [chosen_tokens, rejected_tokens]: + if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length: + for k in ["input_ids", "attention_mask"]: + answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length] + + # Create labels + chosen_sequence_tokens = { + k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"] + } + rejected_sequence_tokens = { + k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"] + } + chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] + chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [ + self.label_pad_token_id + ] * len(chosen_tokens["prompt_input_ids"]) + rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:] + rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [ + self.label_pad_token_id + ] * len(rejected_tokens["prompt_input_ids"]) + + for k, toks in { + "chosen_": chosen_sequence_tokens, + "rejected_": rejected_sequence_tokens, + "": prompt_tokens, + }.items(): + for type_key, tokens in toks.items(): + if type_key == "token_type_ids": + continue + batch[f"{k}{type_key}"] = tokens + + else: + chosen_tokens = self.processing_class( + chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True + ) + rejected_tokens = self.processing_class( + rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True + ) + prompt_tokens = self.processing_class( + prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True + ) + + batch["chosen_labels"] = chosen_tokens["input_ids"] + batch["rejected_labels"] = rejected_tokens["input_ids"] + batch["prompt_input_ids"] = prompt_tokens["input_ids"] + batch["prompt_attention_mask"] = prompt_tokens["attention_mask"] + + if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"): + batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( + labels=torch.tensor(batch["rejected_labels"]) + ) + batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels( + labels=torch.tensor(batch["chosen_labels"]) + ) + + if is_torch_xla_available(): + # Pad the sequences to global max_length to avoid TorchXLA recompilation + for k in batch: + if "labels" in k or self.is_encoder_decoder: + pad_value = self.label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = self.padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + batch[k] = batch[k] + [pad_value] * (self.max_length - len(batch[k])) + return batch + + @staticmethod + def concatenated_inputs( + batch: dict[str, Union[list, torch.LongTensor]], + is_encoder_decoder: bool = False, + label_pad_token_id: int = -100, + padding_value: int = 0, + device: Optional[torch.device] = None, + ) -> dict[str, torch.LongTensor]: + """Concatenate the chosen and rejected inputs into a single tensor. + + Args: + batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length). + is_encoder_decoder: Whether the model is an encoder-decoder model. + label_pad_token_id: The label pad token id. + padding_value: The padding value to use for the concatenated inputs_ids. + device: The device for the concatenated inputs. + + Returns: + A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. + """ + concatenated_batch = {} + + if is_encoder_decoder: + max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1]) + else: + max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) + + for k in batch: + if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): + if "labels" in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + concatenated_key = k.replace("chosen", "concatenated") + concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) + for k in batch: + if k.startswith("rejected") and isinstance(batch[k], torch.Tensor): + if "labels" in k or is_encoder_decoder: + pad_value = label_pad_token_id + elif k.endswith("_input_ids"): + pad_value = padding_value + elif k.endswith("_attention_mask"): + pad_value = 0 + concatenated_key = k.replace("rejected", "concatenated") + concatenated_batch[concatenated_key] = torch.cat( + ( + concatenated_batch[concatenated_key], + pad_to_length(batch[k], max_length, pad_value=pad_value), + ), + dim=0, + ).to(device=device) + + if is_encoder_decoder: + concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device) + concatenated_batch["concatenated_attention_mask"] = ( + batch["prompt_attention_mask"].repeat(2, 1).to(device=device) + ) + + return concatenated_batch + + def odds_ratio_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Compute ORPO's odds ratio (OR) loss for a batch of policy and reference model log probabilities. + + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + + Returns: + A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). + The losses tensor contains the ORPO loss for each example in the batch. + The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. + The log odds ratio of the chosen responses over the rejected responses ratio for logging purposes. + The `log(sigmoid(log_odds_chosen))` for logging purposes. + """ + + # Derived from Eqs. (4) and (7) from https://huggingface.co/papers/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x) + log_odds = (policy_chosen_logps - policy_rejected_logps) - ( + torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps)) + ) + ratio = F.logsigmoid(log_odds) + losses = self.beta * ratio + + chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach() + rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach() + + return losses, chosen_rewards, rejected_rewards, torch.mean(ratio), torch.mean(log_odds) + + @staticmethod + def get_batch_logps( + logits: torch.FloatTensor, + labels: torch.LongTensor, + average_log_prob: bool = False, + label_pad_token_id: int = -100, + is_encoder_decoder: bool = False, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length) + average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. + label_pad_token_id: The label pad token id. + is_encoder_decoder: Whether the model is an encoder-decoder model. + + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. + """ + if logits.shape[:-1] != labels.shape: + raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") + + if not is_encoder_decoder: + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + loss_mask = labels != label_pad_token_id + + # dummy token; we'll ignore the losses on these tokens later + labels = torch.where(labels == label_pad_token_id, 0, labels) + + per_token_logps = selective_log_softmax(logits, labels) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def concatenated_forward( + self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]] + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + + We do this to avoid doing two forward passes, because it's faster for FSDP. + """ + concatenated_batch = self.concatenated_inputs( + batch, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + padding_value=self.padding_value, + device=self.accelerator.device, + ) + len_chosen = batch["chosen_labels"].shape[0] + + model_kwargs = ( + { + "decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]), + } + if self.is_encoder_decoder + else {} + ) + + if self.aux_loss_enabled: + model_kwargs["output_router_logits"] = True + + outputs = model( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + use_cache=False, + **model_kwargs, + ) + all_logits = outputs.logits + + def cross_entropy_loss(logits, labels): + if not self.is_encoder_decoder: + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + logits = logits.view(-1, logits.shape[-1]) + labels = labels.view(-1) + # Enable model parallelism + labels = labels.to(logits.device) + loss = loss_fct(logits, labels) + return loss + + if self.is_encoder_decoder: + labels = concatenated_batch["concatenated_labels"].clone() + else: + labels = concatenated_batch["concatenated_input_ids"].clone() + attention_mask = concatenated_batch["concatenated_attention_mask"] + labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id) + # orpo chosen nll loss is computed over the full prompt and response + chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen]) + + all_logps = self.get_batch_logps( + all_logits, + concatenated_batch["concatenated_labels"], + average_log_prob=True, + is_encoder_decoder=self.is_encoder_decoder, + label_pad_token_id=self.label_pad_token_id, + ) + + chosen_logps = all_logps[:len_chosen] + rejected_logps = all_logps[len_chosen:] + + if not self.is_encoder_decoder: + chosen_logits = all_logits[:len_chosen, :-1, :] + rejected_logits = all_logits[len_chosen:, :-1, :] + else: + chosen_logits = all_logits[:len_chosen] + rejected_logits = all_logits[len_chosen:] + + if self.aux_loss_enabled: + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss) + + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss) + + def get_batch_loss_metrics( + self, + model, + batch: dict[str, Union[list, torch.LongTensor]], + train_eval: Literal["train", "eval"] = "train", + ): + """Compute the ORPO loss and other metrics for the given batch of inputs for train or test.""" + metrics = {} + + forward_output = self.concatenated_forward(model, batch) + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + policy_nll_loss, + ) = forward_output[:5] + if self.aux_loss_enabled: + aux_loss = forward_output[5] + + losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss( + policy_chosen_logps, policy_rejected_logps + ) + # full ORPO loss + loss = policy_nll_loss - losses.mean() + + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + prefix = "eval_" if train_eval == "eval" else "" + metrics[f"{prefix}rewards/chosen"] = self.accelerator.gather_for_metrics(chosen_rewards).mean() + metrics[f"{prefix}rewards/rejected"] = self.accelerator.gather_for_metrics(rejected_rewards).mean() + metrics[f"{prefix}rewards/accuracies"] = self.accelerator.gather_for_metrics(reward_accuracies).mean() + metrics[f"{prefix}rewards/margins"] = self.accelerator.gather_for_metrics( + chosen_rewards - rejected_rewards + ).mean() + metrics[f"{prefix}logps/rejected"] = self.accelerator.gather_for_metrics(policy_rejected_logps).detach().mean() + metrics[f"{prefix}logps/chosen"] = self.accelerator.gather_for_metrics(policy_chosen_logps).detach().mean() + metrics[f"{prefix}logits/rejected"] = self.accelerator.gather_for_metrics( + policy_rejected_logits.detach().mean() + ).mean() + metrics[f"{prefix}logits/chosen"] = self.accelerator.gather_for_metrics( + policy_chosen_logits.detach().mean() + ).mean() + metrics[f"{prefix}nll_loss"] = self.accelerator.gather_for_metrics(policy_nll_loss).detach().mean() + metrics[f"{prefix}log_odds_ratio"] = self.accelerator.gather_for_metrics(log_odds_ratio).detach().mean() + metrics[f"{prefix}log_odds_chosen"] = self.accelerator.gather_for_metrics(log_odds_chosen).detach().mean() + if is_torch_xla_available(): + xm.mark_step() # needed because .item() calls + for k, v in metrics.items(): + metrics[k] = v.item() + if self.aux_loss_enabled: + loss += self.aux_loss_coef * aux_loss + + return loss, metrics + + def compute_loss( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: dict[str, Union[torch.Tensor, Any]], + return_outputs=False, + num_items_in_batch=None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]: + compute_loss_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with compute_loss_context_manager: + loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train") + + # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class: + loss = loss.to(self.args.device) + + # force log the metrics + self.store_metrics(metrics, train_eval="train") + + if return_outputs: + return (loss, metrics) + return loss + + def generate_from_model(self, model, batch: dict[str, torch.LongTensor]) -> str: + """Generate samples from the model and reference model for the given batch of inputs.""" + + # If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with + # the torch amp context manager as some hidden states are silently casted to full precision. + generate_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with generate_context_manager: + policy_output = model.generate( + input_ids=batch["prompt_input_ids"], + attention_mask=batch["prompt_attention_mask"], + max_length=self.max_length, + do_sample=True, + pad_token_id=self.processing_class.pad_token_id, + ) + + policy_output = pad_to_length(policy_output, self.max_length, self.processing_class.pad_token_id) + policy_output_decoded = self.processing_class.batch_decode(policy_output, skip_special_tokens=True) + + return policy_output_decoded + + def prediction_step( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[list[str]] = None, + ): + if not self.use_dpo_data_collator: + warnings.warn( + "prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than " + "DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator" + ) + if ignore_keys is None: + if hasattr(model, "config"): + ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + prediction_context_manager = ( + autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() + ) + + with torch.no_grad(), prediction_context_manager: + loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval") + + # force log the metrics + self.store_metrics(metrics, train_eval="eval") + + if prediction_loss_only: + return (loss.detach(), None, None) + + # logits for the chosen and rejected samples from model + logits_dict = { + "eval_logits/chosen": metrics["eval_logits/chosen"], + "eval_logits/rejected": metrics["eval_logits/rejected"], + } + logits = [v for k, v in logits_dict.items() if k not in ignore_keys] + logits = torch.tensor(logits, device=self.accelerator.device) + labels = torch.zeros(logits.shape[0], device=self.accelerator.device) + + return (loss.detach(), logits, labels) + + def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None: + for key, value in metrics.items(): + self._stored_metrics[train_eval][key].append(value) + + def evaluation_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: Optional[bool] = None, + ignore_keys: Optional[list[str]] = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Overriding built-in evaluation loop to store metrics for each batch. + Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. + """ + + # Sample and save to game log if requested (for one batch to save time) + if self.generate_during_eval: + # Generate random indices within the range of the total number of samples + num_samples = len(dataloader.dataset) + random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size) + + # Use dataloader.dataset.select to get the random batch without iterating over the DataLoader + random_batch_dataset = dataloader.dataset.select(random_indices) + random_batch = self.data_collator(random_batch_dataset) + random_batch = self._prepare_inputs(random_batch) + + policy_output_decoded = self.generate_from_model(self.model, random_batch) + + table = pd.DataFrame( + columns=["Prompt", "Policy"], + data=[ + [prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded) + ], + ) + if "wandb" in self.args.report_to: + wandb.log({"game_log": wandb.Table(data=table)}) + + if "comet_ml" in self.args.report_to: + log_table_to_comet_experiment( + name="game_log.csv", + table=table, + ) + + # Base evaluation + initial_output = super().evaluation_loop( + dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix + ) + + return initial_output + + def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: + """ + Log `logs` on the various objects watching training, including stored metrics. + + Args: + logs (`dict[str, float]`): + The values to log. + start_time (`float` or `None`, *optional*, defaults to `None`): + Start time of the training. + """ + # logs either has 'loss' or 'eval_loss' + train_eval = "train" if "loss" in logs else "eval" + # Add averaged stored metrics to logs + for key, metrics in self._stored_metrics[train_eval].items(): + logs[key] = torch.tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + return super().log(logs, start_time) + + def _shift_right(self, input_ids): + if self.decoder_start_token_id is None: + raise ValueError( + "model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id." + ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = self.decoder_start_token_id + + if self.pad_token_id is None: + raise ValueError("model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id) + + return shifted_input_ids + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, list[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str` or `None`, *optional*, defaults to `None`): + Name of the model. + dataset_name (`str` or `None`, *optional*, defaults to `None`): + Name of the dataset used for training. + tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. + """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + tags = tags or set() + if isinstance(tags, str): + tags = {tags} + + if hasattr(self.model.config, "unsloth_version"): + tags.add("unsloth") + + tags.update(self._tag_names) + + citation = textwrap.dedent("""\ + @article{hong2024orpo, + title = {{ORPO: Monolithic Preference Optimization without Reference Model}}, + author = {Jiwoo Hong and Noah Lee and James Thorne}, + year = 2024, + eprint = {arXiv:2403.07691} + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), + trainer_name="ORPO", + trainer_citation=citation, + paper_title="ORPO: Monolithic Preference Optimization without Reference Model", + paper_id="2403.07691", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/ppo_config.py b/trl/trainer/ppo_config.py new file mode 100644 index 0000000000000000000000000000000000000000..771cea181e06834a8f71e879aa06af97bb665606 --- /dev/null +++ b/trl/trainer/ppo_config.py @@ -0,0 +1,135 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from dataclasses import dataclass, field +from typing import Literal, Optional + +from ..trainer.utils import OnPolicyConfig + + +@dataclass +class PPOConfig(OnPolicyConfig): + r""" + Configuration class for the [`PPOTrainer`]. + + This class includes only the parameters that are specific to PPO training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] and [`OnPolicyConfig`] documentation. Note that default + values in this class may differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`): + Name of this experiment. + reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`): + Path to the reward model. + model_adapter_name (`str` or `None`, *optional*, defaults to `None`): + Name of the train target PEFT adapter, when using LoRA with multiple adapters. + ref_adapter_name (`str` or `None`, *optional*, defaults to `None`): + Name of the reference PEFT adapter, when using LoRA with multiple adapters. + num_ppo_epochs (`int`, *optional*, defaults to `4`): + Number of epochs to train. + whiten_rewards (`bool`, *optional*, defaults to `False`): + Whether to whiten the rewards. + kl_coef (`float`, *optional*, defaults to `0.05`): + KL coefficient. + kl_estimator (`Literal["k1", "k3"]`, *optional*, defaults to `"k1"`): + Which estimator for KL-Divergence to use from [Approximating KL Divergence](http://joschu.net/blog/kl-approx.html). + Defaults to "k1", a straightforward, unbiased estimator. Can be set to "k3", an unbiased estimator with + lower variance which "appears to be a strictly better estimator". Cannot be set to "k2", as it is used for + logging purposes. + cliprange (`float`, *optional*, defaults to `0.2`): + Clip range. + vf_coef (`float`, *optional*, defaults to `0.1`): + Value function coefficient. + cliprange_value (`float`, *optional*, defaults to `0.2`): + Clip range for the value function. + gamma (`float`, *optional*, defaults to `1.0`): + Discount factor. + lam (`float`, *optional*, defaults to `0.95`): + Lambda value for GAE. + ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): + This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, + improving generation speed. However, disabling this option allows training models that exceed the VRAM + capacity of a single GPU, albeit at the cost of slower generation. + """ + + exp_name: str = field( + default=os.path.basename(__file__)[:-3], + metadata={"help": "Name of this experiment."}, + ) + reward_model_path: str = field( + default="EleutherAI/pythia-160m", + metadata={"help": "Path to the reward model."}, + ) + model_adapter_name: Optional[str] = field( + default=None, + metadata={"help": "Name of the train target PEFT adapter, when using LoRA with multiple adapters."}, + ) + ref_adapter_name: Optional[str] = field( + default=None, + metadata={"help": "Name of the reference PEFT adapter, when using LoRA with multiple adapters."}, + ) + num_ppo_epochs: int = field( + default=4, + metadata={"help": "Number of epochs to train."}, + ) + whiten_rewards: bool = field( + default=False, + metadata={"help": "Whether to whiten the rewards."}, + ) + kl_coef: float = field( + default=0.05, + metadata={"help": "KL coefficient."}, + ) + kl_estimator: Literal["k1", "k3"] = field( + default="k1", + metadata={ + "help": "Which estimator for KL-Divergence to use from Approximating KL Divergence " + "(http://joschu.net/blog/kl-approx.html). Defaults to 'k1', a straightforward, unbiased estimator. Can be " + "set to 'k3', an unbiased estimator with lower variance which 'appears to be a strictly better " + "estimator'. Cannot be set to 'k2', as it is used for logging purposes." + }, + ) + cliprange: float = field( + default=0.2, + metadata={"help": "Clip range."}, + ) + vf_coef: float = field( + default=0.1, + metadata={"help": "Value function coefficient."}, + ) + cliprange_value: float = field( + default=0.2, + metadata={"help": "Clip range for the value function."}, + ) + gamma: float = field( + default=1.0, + metadata={"help": "Discount factor."}, + ) + lam: float = field( + default=0.95, + metadata={"help": "Lambda value for GAE."}, + ) + ds3_gather_for_generation: bool = field( + default=True, + metadata={ + "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for " + "generation, improving generation speed. However, disabling this option allows training models that " + "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation." + }, + ) diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..38cc058522ed3f3a66ac4f42ce0e0c9535da9e7e --- /dev/null +++ b/trl/trainer/ppo_trainer.py @@ -0,0 +1,816 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import math +import os +import textwrap +import time +from collections import defaultdict +from contextlib import contextmanager, nullcontext +from pathlib import Path +from typing import Optional, Union + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +from accelerate import Accelerator +from accelerate.utils import broadcast, gather_object +from datasets import Dataset +from torch.utils.data import DataLoader +from transformers import ( + BaseImageProcessor, + DataCollatorWithPadding, + FeatureExtractionMixin, + GenerationConfig, + PreTrainedTokenizerBase, + ProcessorMixin, + Trainer, + TrainerCallback, + TrainerControl, + is_wandb_available, +) +from transformers.integrations import get_reporting_integration_callbacks +from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK +from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback +from transformers.utils import is_peft_available, is_rich_available + +from ..core import masked_mean, masked_whiten +from ..models import create_reference_model +from ..models.utils import unwrap_model_for_generation +from .ppo_config import PPOConfig +from .utils import ( + OnlineTrainerState, + batch_generation, + disable_dropout_in_model, + empty_cache, + exact_div, + first_true_indices, + forward, + generate_model_card, + get_comet_experiment_url, + get_reward, + log_table_to_comet_experiment, + peft_module_casting_to_bf16, + prepare_deepspeed, + print_rich_table, + selective_log_softmax, + truncate_response, +) + + +if is_peft_available(): + from peft import PeftConfig, PeftModel, get_peft_model + +if is_wandb_available(): + import wandb + + +INVALID_LOGPROB = 1.0 + + +# taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29 +# we did this we can do a single `model = accelerator.prepare(model)` +class PolicyAndValueWrapper(nn.Module): + def __init__(self, policy, value_model) -> None: + super().__init__() + self.policy = policy + self.value_model = value_model + self.critic_backbone = getattr(value_model, value_model.base_model_prefix) + + def forward(self, **kwargs): + output = self.critic_backbone(**kwargs) + logits = self.value_model.score(output.hidden_states[-1]) + return self.policy(**kwargs), logits + + +class PPOTrainer(Trainer): + _tag_names = ["trl", "ppo"] + + def __init__( + self, + args: PPOConfig, + processing_class: Optional[ + Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] + ], + model: nn.Module, + ref_model: Optional[nn.Module], + reward_model: nn.Module, + train_dataset: Dataset, + value_model: nn.Module, + data_collator: Optional[DataCollatorWithPadding] = None, + eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, + # less commonly used + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + callbacks: Optional[list[TrainerCallback]] = None, + peft_config: Optional["PeftConfig"] = None, + ) -> None: + if ref_model is model: + raise ValueError( + "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " + "same as `model`, you must make a copy of it, or `None` if you use peft." + ) + + self.args = args + self.processing_class = processing_class + self.policy_model = model + + # Define the collator if not provided + if data_collator is None: + data_collator = DataCollatorWithPadding(self.processing_class) + + # Handle stop token settings: update policy model's generation_config to use provided stop token + if args.stop_token and args.stop_token_id: + raise ValueError("You cannot set both `stop_token` and `stop_token_id`.") + elif args.stop_token: + if args.stop_token == "eos": + self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id + else: + raise ValueError( + f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)." + ) + else: + self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id # None or int + + # Check that the kl estimator is valid + if self.args.kl_estimator not in {"k1", "k3"}: + raise ValueError( + "kl_estimator must be either 'k1' (straightforward, unbiased) or 'k3' (lower variance, unbiased, " + "appears to be a strictly better estimator). See " + "[Approximating KL Divergence](http://joschu.net/blog/kl-approx.html) for details." + ) + + # peft support + if not is_peft_available() and peft_config is not None: + raise ImportError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + # if model is a peft model and we have a peft_confg, we merge and unload it first + if isinstance(self.policy_model, PeftModel): + self.policy_model = self.policy_model.merge_and_unload() + + # get peft model with the given config + self.policy_model = get_peft_model(self.policy_model, peft_config) + if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False): + peft_module_casting_to_bf16(self.policy_model) + + self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel) + self.model_adapter_name = args.model_adapter_name + self.ref_adapter_name = args.ref_adapter_name + + if ref_model: + self.ref_model = ref_model + elif self.is_peft_model: + self.ref_model = None + else: + self.ref_model = create_reference_model(self.policy_model) + + self.reward_model = reward_model + self.train_dataset = train_dataset + self.train_dataset_len = len(train_dataset) + self.value_model = value_model + self.data_collator = data_collator + self.eval_dataset = eval_dataset + self.optimizer, self.lr_scheduler = optimizers + self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47 + + ######### + # calculate various batch sizes + ######### + if args.total_episodes is None: # allow the users to define episodes in terms of epochs. + args.total_episodes = int(args.num_train_epochs * self.train_dataset_len) + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + self.accelerator = accelerator + args.world_size = accelerator.num_processes + args.local_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps + args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size) + args.batch_size = int(args.local_batch_size * args.world_size) + args.mini_batch_size = exact_div( + args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`" + ) + args.local_mini_batch_size = exact_div( + args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`" + ) + if args.whiten_rewards: + assert args.local_mini_batch_size >= 8, ( + f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" + ) + # `per_rank_rollout_batch_size` is our `args.local_batch_size` + # `per_rank_minibatch_size` is our `args.local_mini_batch_size` + args.num_total_batches = math.ceil( + args.total_episodes / args.batch_size + ) # we may train for more than `total_episodes` + time_tensor = torch.tensor(int(time.time()), device=accelerator.device) + time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes + args.run_name = f"{args.exp_name}__{args.seed}__{time_int}" + self.local_seed = args.seed + accelerator.process_index * 100003 # Prime + if args.num_sample_generations > 0: + self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations) + self.local_dataloader_batch_size = args.local_batch_size + + ######### + # setup model, optimizer, and others + ######### + for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]: + if module is not None: + disable_dropout_in_model(module) + self.model = PolicyAndValueWrapper(self.policy_model, self.value_model) + self.model.config = self.policy_model.config # needed for pushing to hub + self.create_optimizer_and_scheduler( + num_training_steps=args.num_total_batches + ) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level + + ######### + ### trainer specifics + ######### + default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) + self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks + self.callback_handler = CallbackHandler( + self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler + ) + self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) + self.control = TrainerControl() + self.state = OnlineTrainerState( + is_local_process_zero=self.is_local_process_zero(), + is_world_process_zero=self.is_world_process_zero(), + stateful_callbacks=[ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ], + ) + self.current_flos = 0 + self.hp_search_backend = None + self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None + self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None + # Create distant repo and output directory if needed + self.hub_model_id = None + if self.args.push_to_hub: + self.init_hf_repo() + if self.args.should_save: + os.makedirs(self.args.output_dir, exist_ok=True) + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + ######### + ### setup dataloader + ######### + self.dataloader = DataLoader( + self.train_dataset, + batch_size=self.local_dataloader_batch_size, + shuffle=True, + collate_fn=self.data_collator, + drop_last=True, # needed; otherwise the last batch will be of ragged shape + ) + # sync random states for DataLoader(shuffle=True) before `accelerator.prepare` + # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c + torch.manual_seed(args.seed) + self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader) + torch.manual_seed(self.local_seed) # reset the local seed again + + self.eval_dataloader = DataLoader( + self.eval_dataset, + batch_size=args.per_device_eval_batch_size, + collate_fn=self.data_collator, + drop_last=True, + ) # no need to shuffle eval dataset + self.eval_dataloader = accelerator.prepare(self.eval_dataloader) + + if self.is_deepspeed_enabled: + self.reward_model = prepare_deepspeed( + self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 + ) + + if self.ref_model is None: + if not self.is_peft_model: + raise ValueError("No reference model and model is not a Peft model.") + else: + self.ref_model = prepare_deepspeed( + self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16 + ) + else: + if self.ref_model is None: + if not self.is_peft_model: + raise ValueError("No reference model and model is not a Peft model.") + else: + self.ref_model = self.ref_model.to(self.accelerator.device) + self.reward_model = self.reward_model.to(self.accelerator.device) + + def get_train_dataloader(self) -> DataLoader: + return self.dataloader + + def get_eval_dataloader(self) -> DataLoader: + return self.eval_dataloader + + @contextmanager + def null_ref_context(self): + """Context manager for handling null reference model (that is, peft adapter manipulation).""" + with ( + self.accelerator.unwrap_model(self.model.policy).disable_adapter() + if self.is_peft_model and not self.ref_adapter_name + else nullcontext() + ): + if self.ref_adapter_name: + self.model.policy.set_adapter(self.ref_adapter_name) + yield + if self.ref_adapter_name: + self.model.policy.set_adapter(self.model_adapter_name or "default") + + def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): + backup_model = self.model + self.model = self.model.policy # save only the policy + + if self.is_deepspeed_enabled: + backup_deepspeed = self.deepspeed + self.deepspeed = self.model + + super().save_model(output_dir, _internal_call) + + self.model = backup_model + + if self.is_deepspeed_enabled: + self.deepspeed = backup_deepspeed + + def train(self): + args = self.args + accelerator = self.accelerator + optimizer = self.optimizer + model = self.model + ref_policy = self.ref_model + reward_model = self.reward_model + processing_class = self.processing_class + dataloader = self.dataloader + device = accelerator.device + + def repeat_generator(): + while True: + yield from dataloader + + iter_dataloader = iter(repeat_generator()) + generation_config = GenerationConfig( + max_new_tokens=args.response_length, + temperature=(args.temperature + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + accelerator.print("===training policy===") + start_time = time.time() + stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps) + approxkl_stats = torch.zeros(stats_shape, device=device) + pg_clipfrac_stats = torch.zeros(stats_shape, device=device) + pg_loss_stats = torch.zeros(stats_shape, device=device) + vf_loss_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropy_stats = torch.zeros(stats_shape, device=device) + ratio_stats = torch.zeros(stats_shape, device=device) + model.train() + + # trainer state initialization + self.state.global_step = 0 + self.state.episode = 0 + self.state.max_steps = args.num_total_batches + self.state.num_train_epochs = args.total_episodes / self.train_dataset_len + # Compute absolute values for logging, eval, and save if given as ratio + if args.logging_steps is not None: + if args.logging_steps < 1: + self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps) + else: + self.state.logging_steps = args.logging_steps + if args.eval_steps is not None: + if args.eval_steps < 1: + self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps) + else: + self.state.eval_steps = args.eval_steps + if args.save_steps is not None: + if args.save_steps < 1: + self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps) + else: + self.state.save_steps = args.save_steps + self.control = self.callback_handler.on_train_begin(args, self.state, self.control) + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model + self.model_wrapped = self.model + + for update in range(1, args.num_total_batches + 1): + self.state.episode += 1 * args.batch_size + data = next(iter_dataloader) + with torch.no_grad(): + queries = data["input_ids"].to(device) + context_length = queries.shape[1] + responses = [] + postprocessed_responses = [] + logprobs = [] + ref_logprobs = [] + scores = [] + sequence_lengths = [] + values = [] + with unwrap_model_for_generation( + self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model: + query_responses, logitss = batch_generation( + unwrapped_model.policy, + queries, + args.local_rollout_forward_batch_size, + processing_class.pad_token_id, + generation_config, + ) + + for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): + query = queries[i : i + args.local_rollout_forward_batch_size] + query_response = query_responses[i : i + args.local_rollout_forward_batch_size] + response = query_response[:, context_length:] + logits = logitss[i : i + args.local_rollout_forward_batch_size] + logprob = selective_log_softmax(logits, response) + del logits + empty_cache() + + if ref_policy is None: + with self.null_ref_context(): + ref_output = forward(model.policy, query_response, processing_class.pad_token_id) + else: + ref_output = forward(ref_policy, query_response, processing_class.pad_token_id) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= args.temperature + 1e-7 + ref_logprob = selective_log_softmax(ref_logits, response) + del ref_output, ref_logits + empty_cache() + + # Response Processing 1. truncate response after the first occurrence of `stop_token_id` + postprocessed_response = response + if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + self.stop_token_id, processing_class.pad_token_id, response + ) + + # Response Processing 2. run reward model on the truncated responses + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1 + unwrapped_value_model = accelerator.unwrap_model(model).value_model + full_value, _, _ = get_reward( + unwrapped_value_model, query_response, processing_class.pad_token_id, context_length + ) + value = full_value[:, context_length - 1 : -1].squeeze(-1) + _, score, _ = get_reward( + reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length + ) + + responses.append(response) + postprocessed_responses.append(postprocessed_response) + logprobs.append(logprob) + ref_logprobs.append(ref_logprob) + sequence_lengths.append(sequence_length) + scores.append(score) + values.append(value) + responses = torch.cat(responses, 0) + postprocessed_responses = torch.cat(postprocessed_responses, 0) + logprobs = torch.cat(logprobs, 0) + ref_logprobs = torch.cat(ref_logprobs, 0) + sequence_lengths = torch.cat(sequence_lengths, 0) + scores = torch.cat(scores, 0) + values = torch.cat(values, 0) + del (logprob, ref_logprob, full_value, value, score, unwrapped_model) + empty_cache() + gc.collect() + + # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id + # Completions not passing that filter will receive a lower score. + contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1) + if self.args.missing_eos_penalty is not None: + scores[~contain_eos_token] -= self.args.missing_eos_penalty + # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}") + + # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw + response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) + padding_mask = response_idxs > sequence_lengths.unsqueeze(1) + logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) + ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) + sequence_lengths_p1 = sequence_lengths + 1 + padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1)) + values = torch.masked_fill(values, padding_mask_p1, 0) + + # 4. compute rewards + # Formula used by http://joschu.net/blog/kl-approx.html for the k1 and k3 estimators + logr = ref_logprobs - logprobs + kl = -logr if args.kl_estimator == "k1" else (logr.exp() - 1) - logr # Else statement is k3 + non_score_reward = -args.kl_coef * kl + rewards = non_score_reward.clone() + actual_start = torch.arange(rewards.size(0), device=rewards.device) + actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths) + rewards[[actual_start, actual_end]] += scores + + # 5. whiten rewards + if args.whiten_rewards: + rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False) + rewards = torch.masked_fill(rewards, padding_mask_p1, 0) + + # 6. compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = responses.shape[1] + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.gamma * args.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + advantages = masked_whiten(advantages, ~padding_mask) + advantages = torch.masked_fill(advantages, padding_mask, 0) + empty_cache() + + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch + for ppo_epoch_idx in range(args.num_ppo_epochs): + b_inds = np.random.permutation(args.local_batch_size) + minibatch_idx = 0 + for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size): + with accelerator.accumulate(model): + micro_batch_end = micro_batch_start + args.per_device_train_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_advantage = advantages[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + mb_return = returns[micro_batch_inds] + mb_values = values[micro_batch_inds] + + output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.temperature + 1e-7 + new_logprobs = selective_log_softmax(logits, mb_responses) + new_logprobs = torch.masked_fill( + new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB + ) + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0) + vpredclipped = torch.clamp( + vpred, + mb_values - args.cliprange_value, + mb_values + args.cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss_max = torch.max(vf_losses1, vf_losses2) + vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds]) + vf_clipfrac = masked_mean( + (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds] + ) + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange) + pg_loss_max = torch.max(pg_losses, pg_losses2) + pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds]) + loss = pg_loss + args.vf_coef * vf_loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + with torch.no_grad(): + pg_clipfrac = masked_mean( + (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds] + ) + prob_dist = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( + pg_clipfrac + ) + pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( + vf_clipfrac + ) + entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() + gradient_accumulation_idx += 1 + minibatch_idx += 1 + # del everything and empty cache + # fmt: off + del ( + output, vpred_temp, logits, new_logprobs, vpred, vpredclipped, + vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max, + pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return, + mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs, + ) + # fmt: on + empty_cache() + with torch.no_grad(): + mean_kl = kl.sum(1).mean() + mean_entropy = (-logprobs).sum(1).mean() + mean_non_score_reward = non_score_reward.sum(1).mean() + rlhf_reward = mean_non_score_reward + scores.mean() + eps = int(self.state.episode / (time.time() - start_time)) + metrics = {} + metrics["eps"] = eps + metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item() + metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item() + metrics["objective/non_score_reward"] = ( + self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item() + ) + metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item() + metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item() + metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item() + metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item() + metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item() + metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item() + metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item() + metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item() + metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item() + metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item() + metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item() + metrics["lr"] = self.lr_scheduler.get_last_lr()[0] + metrics["episode"] = self.state.episode + self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log + self.state.global_step += 1 + self.log(metrics) + + self.lr_scheduler.step() + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + if self.control.should_save: + self._save_checkpoint(model, trial=None) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward + empty_cache() + gc.collect() + + if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0: + self.generate_completions(sampling=True) + empty_cache() + del ( + query_responses, + responses, + postprocessed_responses, + logprobs, + ref_logprobs, + values, + sequence_lengths, + contain_eos_token, + sequence_lengths_p1, + response_idxs, + padding_mask, + padding_mask_p1, + rewards, + actual_start, + actual_end, + advantages, + returns, + ) + empty_cache() + + # HF trainer specifics + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + if self.control.should_save: + self._save_checkpoint(model, trial=None, metrics=None) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + + def generate_completions(self, sampling: bool = False): + args = self.args + processing_class = self.processing_class + generation_config = GenerationConfig( + max_new_tokens=self.args.response_length, + temperature=(0.01 + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + table = defaultdict(list) + with unwrap_model_for_generation( + self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model: + for batch in self.eval_dataloader: + query = batch["input_ids"] + with torch.no_grad(): + context_length = query.shape[1] + query_response, _ = batch_generation( + unwrapped_model.policy, + query, + query.shape[0], + processing_class.pad_token_id, + generation_config, + ) + response = query_response[:, context_length:] + postprocessed_response = response + if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + self.stop_token_id, processing_class.pad_token_id, response + ) + table["query"].extend( + gather_object(processing_class.batch_decode(query, skip_special_tokens=True)) + ) + table["model response"].extend( + gather_object(processing_class.batch_decode(postprocessed_response)) + ) + + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + _, score, _ = get_reward( + self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length + ) + table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy()) + + if sampling: + break + df = pd.DataFrame(table) + + if self.accelerator.is_main_process: + if is_rich_available(): + print_rich_table(df.iloc[0 : 0 + 5]) + if "wandb" in args.report_to: + import wandb + + if wandb.run is not None: + wandb.log({"completions": wandb.Table(dataframe=df)}) + + if "comet_ml" in args.report_to: + log_table_to_comet_experiment( + name="completions.csv", + table=df, + ) + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, list[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str` or `None`, *optional*, defaults to `None`): + Name of the model. + dataset_name (`str` or `None`, *optional*, defaults to `None`): + Name of the dataset used for training. + tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. + """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + tags = tags or set() + if isinstance(tags, str): + tags = {tags} + + if hasattr(self.model.config, "unsloth_version"): + tags.add("unsloth") + + tags.update(self._tag_names) + + citation = textwrap.dedent("""\ + @article{mziegler2019fine-tuning, + title = {{Fine-Tuning Language Models from Human Preferences}}, + author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving}, + year = 2019, + eprint = {arXiv:1909.08593} + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), + trainer_name="PPO", + trainer_citation=citation, + paper_title="Fine-Tuning Language Models from Human Preferences", + paper_id="1909.08593", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/prm_config.py b/trl/trainer/prm_config.py new file mode 100644 index 0000000000000000000000000000000000000000..b7ab93ec09782c4856666959037dd665b62eee77 --- /dev/null +++ b/trl/trainer/prm_config.py @@ -0,0 +1,112 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional + +from transformers import TrainingArguments + + +@dataclass +class PRMConfig(TrainingArguments): + r""" + Configuration class for the [`PRMTrainer`]. + + This class includes only the parameters that are specific to PRM training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may + differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum length of the sequences (prompt + completion) used for truncation. + max_prompt_length (`int` or `None`, *optional*, defaults to `512`): + Maximum length of the prompt used for truncation. + max_completion_length (`int` or `None`, *optional*, defaults to `None`): + Maximum length of the completion used for truncation. The completion is the concatenation of the steps. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model. + step_separator (`str`, *optional*, defaults to `"\n"`): + Separator used to separate each step of the reasoning process. + train_on_last_step_only (`bool`, *optional*, defaults to `False`): + Whether to train only on the last step. + dataset_num_proc (`int`, *optional*, defaults to `None`): + Number of processes to use for processing the dataset. + """ + + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=1e-5, + metadata={"help": "The initial learning rate for AdamW."}, + ) + logging_steps: float = field( + default=10, + metadata={ + "help": ( + "Log every X updates steps. Should be an integer or a float in range `[0,1)`. " + "If smaller than 1, will be interpreted as ratio of total training steps." + ) + }, + ) + bf16: bool = field( + default=True, + metadata={ + "help": ( + "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change." + ) + }, + ) + average_tokens_across_devices: bool = field( + default=True, + metadata={ + "help": "Whether or not to average tokens across devices. If enabled, will use all_reduce to synchronize " + "num_tokens_in_batch for precise loss calculation. Reference: https://github.com/huggingface/transformers/issues/34242 " + }, + ) + + max_length: Optional[int] = field( + default=1024, + metadata={"help": "Maximum length of the sequences (prompt + completion) used for truncation."}, + ) + max_prompt_length: Optional[int] = field( + default=512, + metadata={"help": "Maximum length of the prompt used for truncation."}, + ) + max_completion_length: Optional[int] = field( + default=None, + metadata={ + "help": "Maximum length of the completion used for truncation. The completion is the concatenation of the " + "steps." + }, + ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropout in the model and reference model."}, + ) + step_separator: str = field( + default="\n", + metadata={"help": "Separator used to separate each step of the reasoning process."}, + ) + train_on_last_step_only: bool = field( + default=False, + metadata={"help": "Whether to train only on the last step."}, + ) + dataset_num_proc: Optional[int] = field( + default=None, + metadata={"help": "Number of processes to use for processing the dataset."}, + ) diff --git a/trl/trainer/prm_trainer.py b/trl/trainer/prm_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..b69bb71052838afcf4cb826ca9313173c2164903 --- /dev/null +++ b/trl/trainer/prm_trainer.py @@ -0,0 +1,360 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import os +import textwrap +import warnings +from itertools import chain +from pathlib import Path +from typing import Callable, Optional, Union + +import torch +import torch.nn as nn +from accelerate import PartialState +from datasets import Dataset, features +from transformers import ( + BaseImageProcessor, + DataCollator, + DataCollatorForTokenClassification, + FeatureExtractionMixin, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + Trainer, + is_wandb_available, +) +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalPrediction +from transformers.utils import is_peft_available + +from .prm_config import PRMConfig +from .utils import compute_accuracy, disable_dropout_in_model, generate_model_card + + +if is_peft_available(): + from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + +if is_wandb_available(): + import wandb + + +class PRMTrainer(Trainer): + """ + Initialize PRMTrainer. + + Args: + model (`transformers.PreTrainedModel`): + The model to train, preferably an `AutoModelForTokenClassification`. + args (`PRMConfig`): + The arguments to use for training. + data_collator (`transformers.DataCollator`): + The data collator to use for training. If None is specified, the default data collator (`DataCollatorForTokenClassification`) will be used + which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences. + train_dataset (`datasets.Dataset`): + The dataset to use for training. + eval_dataset (`datasets.Dataset`): + The dataset to use for evaluation. + processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*): + Processing class used to process the data. If provided, will be used to automatically process the inputs + for the model, and it will be saved along the model to make it easier to rerun an interrupted training or + reuse the fine-tuned model. + model_init (`Callable[[], transformers.PreTrainedModel]`): + The model initializer to use for training. If None is specified, the default model initializer will be used. + compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`): + The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used. + callbacks (`list[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + peft_config (`dict`, defaults to `None`): + The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. + """ + + _tag_names = ["trl", "prm"] + + def __init__( + self, + model: Optional[Union[PreTrainedModel, nn.Module]] = None, + args: Optional[PRMConfig] = None, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, + processing_class: Optional[ + Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] + ] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, + callbacks: Optional[list[TrainerCallback]] = None, + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( + None, + None, + ), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + peft_config: Optional[dict] = None, + ): + if not is_peft_available() and peft_config is not None: + raise ValueError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + if not isinstance(model, PeftModel): + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False): + _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + + prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} + + if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None: + warnings.warn( + "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. " + "please update to the latest version of peft to use `gradient_checkpointing_kwargs`." + ) + elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None: + prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) + + model = get_peft_model(model, peft_config) + + # Disable dropout in the model + if args.disable_dropout: + disable_dropout_in_model(model) + + if compute_metrics is None: + compute_metrics = compute_accuracy + + if data_collator is None: + if processing_class is None: + raise ValueError( + "A processing_class must be specified when using the default DataCollatorForTokenClassification" + ) + data_collator = DataCollatorForTokenClassification(processing_class, max_length=args.max_length) + + if "input_ids" not in train_dataset.column_names: + with PartialState().main_process_first(): + fn_kwargs = { + "tokenizer": processing_class, + "step_separator": args.step_separator, + "max_length": args.max_length, + "max_prompt_length": args.max_prompt_length, + "max_completion_length": args.max_completion_length, + "train_on_last_step_only": args.train_on_last_step_only, + } + train_fn_kwargs = {**fn_kwargs, "is_eval": False} + train_dataset = train_dataset.map( + self.tokenize_row, + fn_kwargs=train_fn_kwargs, + num_proc=args.dataset_num_proc, + remove_columns=train_dataset.features, + desc="Tokenizing train dataset", + features=features.Features( # needed to avoid map to cast labels to bool + { + "labels": features.Sequence(features.Value("int64")), + "input_ids": features.Sequence(features.Value("int64")), + } + ), + ) + + eval_fn_kwargs = {**fn_kwargs, "is_eval": True} + if eval_dataset is not None: + eval_dataset = eval_dataset.map( + self.tokenize_row, + fn_kwargs=eval_fn_kwargs, + num_proc=args.dataset_num_proc, + remove_columns=eval_dataset.features, + desc="Tokenizing eval dataset", + features=features.Features( # needed to avoid map to cast labels to bool + { + "labels": features.Sequence(features.Value("int64")), + "input_ids": features.Sequence(features.Value("int64")), + } + ), + ) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + @staticmethod + def tokenize_row( + features, + tokenizer, + step_separator, + max_length, + max_prompt_length, + max_completion_length, + train_on_last_step_only, + is_eval, + ): + r""" + Tokenize a row of the dataset. + + Args: + features (`dict[str, str]`): + Row of the dataset, should contain the keys `"prompt"`, `"completions"`, and `"labels"`. + tokenizer (`PreTrainedTokenizerBase`): + Tokenizer used to process the data. + step_separator (`str`): + Separator between steps in the completion. + max_length (`int` or `None`): + Maximum length of the sequences (prompt + completion). If `None`, the sequences are not truncated. + max_prompt_length (`int` or `None`): + Maximum length of the prompt. If `None`, the prompt is not truncated. + max_completion_length (`int` or `None`): + Maximum length of the completion sequences. If `None`, the completion sequences are not truncated. + train_on_last_step_only (`bool`): + Whether to train only on the last step. If `True`, the labels are `-100` for all tokens except the last + token of the completion. + is_eval (`bool`): + Whether the function is used to tokenize samples from a training or an evaluation dataset. Used only if `train_on_last_step_only` is set to `True`. + + Returns: + `dict[str, list[int]]`: + Tokenized sequences with the keys `"input_ids"`, and `"labels". + + Example: + ```python + >>> from transformers import AutoTokenizer + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B") + >>> features = {"prompt": "Which number is larger, 9.8 or 9.11?", + ... "completions": ["11 is greater than 8.", + ... "Hence, 9.11 > 9.8."], + ... "labels": [True, False]} + >>> PRMTrainer.tokenize_row(features, tokenizer, "\n", max_completion_length=None, train_on_last_step_only=False, is_eval=False) + {'input_ids': [23085, 1372, 374, 8131, 11, 220, 24, 13, 23, 476, 220, 24, 13, 16, 16, 30, 16, 16, 374, 7046, 1091, 220, 23, 13, 198, 39, 763, 11, 220, 24, 13, 16, 16, 861, 220, 24, 13, 23, 13, 198], + 'labels': [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0]} + ``` + """ + # Tokenize the prompt and completions + prompt_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"] + completions_ids = [ + tokenizer(completion, add_special_tokens=False)["input_ids"] for completion in features["completions"] + ] + if train_on_last_step_only and not is_eval: + labels = [-100] * (len(features["labels"]) - 1) + [int(features["labels"][-1])] + else: + labels = [int(label) for label in features["labels"]] + + # Get the ID of the separator token and add it to the completions + separator_ids = tokenizer.encode(step_separator, add_special_tokens=False) + completions_ids = [completion + separator_ids for completion in completions_ids] + + # Create the label + labels = [[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels)] + + # Join the completions and labels steps + completion_ids = list(chain(*completions_ids)) + labels = list(chain(*labels)) + + if tokenizer.bos_token_id is not None: + prompt_ids = [tokenizer.bos_token_id] + prompt_ids + + # Truncate prompt and completion sequences + if max_prompt_length is not None: + prompt_ids = prompt_ids[-max_prompt_length:] + if max_completion_length is not None: + completion_ids = completion_ids[:max_completion_length] + labels = labels[:max_completion_length] + + input_ids = prompt_ids + completion_ids + labels = [-100] * len(prompt_ids) + labels + + if max_length is not None: + input_ids = input_ids[:max_length] + labels = labels[:max_length] + + return {"input_ids": input_ids, "labels": labels} + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, list[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str` or `None`, *optional*, defaults to `None`): + Name of the model. + dataset_name (`str` or `None`, *optional*, defaults to `None`): + Name of the dataset used for training. + tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. + """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + tags = tags or set() + if isinstance(tags, str): + tags = {tags} + + if hasattr(self.model.config, "unsloth_version"): + tags.add("unsloth") + + tags.update(self._tag_names) + + citation = textwrap.dedent("""\ + @article{uesato2022solving, + title = {{Solving Math Word Problems With Process- and Outcome-Based Feedback}}, + author = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina}, + year = 2022, + journal = {arXiv preprint arXiv:2211.14275} + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + trainer_name="PRM", + trainer_citation=citation, + paper_title="Solving math word problems with process-and outcome-based feedback", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/reward_config.py b/trl/trainer/reward_config.py new file mode 100644 index 0000000000000000000000000000000000000000..1a0dc6f3145022fc5665d918e427265a37ccdae0 --- /dev/null +++ b/trl/trainer/reward_config.py @@ -0,0 +1,105 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Optional + +from transformers import TrainingArguments + + +@dataclass +class RewardConfig(TrainingArguments): + r""" + Configuration class for the [`RewardTrainer`]. + + This class includes only the parameters that are specific to Reward training. For a full list of training + arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this + class may differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum length of the sequences (prompt + completion) in the batch, filters out entries that exceed the + limit. This argument is required if you want to use the default data collator. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model. + dataset_num_proc (`int`, *optional*, defaults to `None`): + Number of processes to use for processing the dataset. + center_rewards_coefficient (`float`, *optional*, defaults to `None`): + Coefficient to incentivize the reward model to output mean-zero rewards (proposed by + https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`. + remove_unused_columns (`bool`, *optional*, defaults to `False`): + Whether to remove the columns that are not used by the model's forward pass. Can be `True` only if + the dataset is pretokenized. + """ + + # Parameters whose default values are overridden from TrainingArguments + logging_steps: float = field( + default=10, + metadata={ + "help": ( + "Log every X updates steps. Should be an integer or a float in range `[0,1)`. " + "If smaller than 1, will be interpreted as ratio of total training steps." + ) + }, + ) + bf16: bool = field( + default=True, + metadata={ + "help": ( + "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change." + ) + }, + ) + average_tokens_across_devices: bool = field( + default=True, + metadata={ + "help": "Whether or not to average tokens across devices. If enabled, will use all_reduce to synchronize " + "num_tokens_in_batch for precise loss calculation. Reference: https://github.com/huggingface/transformers/issues/34242 " + }, + ) + + max_length: Optional[int] = field( + default=1024, + metadata={ + "help": "Maximum length of the sequences (prompt + completion) in the batch, filters out entries that " + "exceed the limit. This argument is required if you want to use the default data collator." + }, + ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropout in the model and reference model."}, + ) + dataset_num_proc: Optional[int] = field( + default=None, + metadata={"help": "Number of processes to use for processing the dataset."}, + ) + center_rewards_coefficient: Optional[float] = field( + default=None, + metadata={ + "help": "Coefficient to incentivize the reward model to output mean-zero rewards (proposed by " + "https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`." + }, + ) + remove_unused_columns: bool = field( + default=False, + metadata={ + "help": "Whether to remove the columns that are not used by the model's forward pass. Can be `True` only " + "if the dataset is pretokenized." + }, + ) diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..50c8084816bfe0c9158bf91d16e023e69132fefb --- /dev/null +++ b/trl/trainer/reward_trainer.py @@ -0,0 +1,422 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import os +import warnings +from collections import defaultdict +from dataclasses import FrozenInstanceError, replace +from pathlib import Path +from typing import Any, Callable, Optional, Union + +import pandas as pd +import torch +import torch.nn as nn +from accelerate import PartialState +from accelerate.utils import gather_object +from datasets import Dataset +from transformers import ( + BaseImageProcessor, + DataCollator, + FeatureExtractionMixin, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + Trainer, + is_wandb_available, +) +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_pt_utils import nested_detach +from transformers.trainer_utils import EvalPrediction +from transformers.utils import is_peft_available, is_rich_available + +from ..data_utils import maybe_apply_chat_template +from .reward_config import RewardConfig +from .utils import ( + RewardDataCollatorWithPadding, + compute_accuracy, + decode_and_strip_padding, + disable_dropout_in_model, + generate_model_card, + get_comet_experiment_url, + log_table_to_comet_experiment, + print_rich_table, +) + + +if is_peft_available(): + from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + +if is_wandb_available(): + import wandb + + +def _tokenize(batch: dict[str, list[Any]], tokenizer: "PreTrainedTokenizerBase") -> dict[str, list[Any]]: + """Tokenize a batch from a reward modelling dataset.""" + new_examples = { + "input_ids_chosen": [], + "attention_mask_chosen": [], + "input_ids_rejected": [], + "attention_mask_rejected": [], + } + for chosen, rejected in zip(batch["chosen"], batch["rejected"]): + tokenized_chosen = tokenizer(chosen) + tokenized_rejected = tokenizer(rejected) + new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"]) + new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"]) + new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"]) + new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"]) + + return new_examples + + +class RewardTrainer(Trainer): + _tag_names = ["trl", "reward-trainer"] + + def __init__( + self, + model: Optional[Union[PreTrainedModel, nn.Module]] = None, + args: Optional[RewardConfig] = None, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, + processing_class: Optional[ + Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] + ] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, + callbacks: Optional[list[TrainerCallback]] = None, + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( + None, + None, + ), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + peft_config: Optional[dict] = None, + ): + """ + Initialize RewardTrainer. + + Args: + model (`transformers.PreTrainedModel`): + The model to train, preferably an `AutoModelForSequenceClassification`. + args (`RewardConfig`): + The arguments to use for training. + data_collator (`transformers.DataCollator`): + The data collator to use for training. If None is specified, the default data collator (`RewardDataCollatorWithPadding`) will be used + which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences. + train_dataset (`datasets.Dataset`): + The dataset to use for training. + eval_dataset (`datasets.Dataset`): + The dataset to use for evaluation. + processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*): + Processing class used to process the data. If provided, will be used to automatically process the inputs + for the model, and it will be saved along the model to make it easier to rerun an interrupted training or + reuse the fine-tuned model. + model_init (`Callable[[], transformers.PreTrainedModel]`): + The model initializer to use for training. If None is specified, the default model initializer will be used. + compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional* defaults to `compute_accuracy`): + The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used. + callbacks (`list[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + peft_config (`dict`, defaults to `None`): + The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. + """ + if not is_peft_available() and peft_config is not None: + raise ValueError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + if not isinstance(model, PeftModel): + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False): + _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list( + inspect.signature(prepare_model_for_kbit_training).parameters + ) + + prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} + + if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None: + warnings.warn( + "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. " + "please update to the latest version of peft to use `gradient_checkpointing_kwargs`.", + UserWarning, + ) + elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None: + prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs + + model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) + + model = get_peft_model(model, peft_config) + + # Disable dropout in the model + if args.disable_dropout: + disable_dropout_in_model(model) + + if compute_metrics is None: + compute_metrics = compute_accuracy + + if data_collator is None: + if processing_class is None: + raise ValueError( + "A processing_class must be specified when using the default RewardDataCollatorWithPadding" + ) + + max_length = args.max_length + + data_collator = RewardDataCollatorWithPadding(processing_class) + + if args.remove_unused_columns: + try: # for bc before https://github.com/huggingface/transformers/pull/25435 + args.remove_unused_columns = False + except FrozenInstanceError: + args = replace(args, remove_unused_columns=False) + # warn users + warnings.warn( + "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig" + " we have set it for you, but you should do it yourself in the future.", + UserWarning, + ) + + self.use_reward_data_collator = True + else: + self.use_reward_data_collator = False + + # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the + # input tensor associated with the key "input_ids". However, in Reward, the sampled data does not include the + # "input_ids" key. Instead, the available keys are "input_ids_chosen" and "input_ids_rejected". As a result, + # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point + # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's + # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been + # issued. + model.warnings_issued["estimate_tokens"] = True + + if "input_ids_chosen" not in train_dataset.column_names: + with PartialState().main_process_first(): + fn_kwargs = {"tokenizer": processing_class} + train_dataset = train_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}) + train_dataset = train_dataset.map( + _tokenize, + batched=True, + fn_kwargs=fn_kwargs, + num_proc=args.dataset_num_proc, + ) + # This filter is important because otherwise you get samples that exceed the model's context length and + # get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the + # user might get surprised if N samples are missing from training. + train_dataset = train_dataset.filter( + lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length, + num_proc=args.dataset_num_proc, + ) + if eval_dataset is not None: + eval_dataset = eval_dataset.map( + maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class} + ) + eval_dataset = eval_dataset.map( + _tokenize, + fn_kwargs=fn_kwargs, + batched=True, + num_proc=args.dataset_num_proc, + ) + # This filter is important because otherwise you get samples that exceed the model's context length and + # get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the + # user might get surprised if N samples are missing from training. + eval_dataset = eval_dataset.filter( + lambda x: len(x["input_ids_chosen"]) <= max_length + and len(x["input_ids_rejected"]) <= max_length, + num_proc=args.dataset_num_proc, + ) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + model_init=model_init, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + def compute_loss( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: dict[str, Union[torch.Tensor, Any]], + return_outputs=False, + num_items_in_batch=None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]: + rewards_chosen = model( + input_ids=inputs["input_ids_chosen"], + attention_mask=inputs["attention_mask_chosen"], + return_dict=True, + )["logits"] + rewards_rejected = model( + input_ids=inputs["input_ids_rejected"], + attention_mask=inputs["attention_mask_rejected"], + return_dict=True, + )["logits"] + # calculate loss, optionally modulate with margin + if "margin" in inputs: + loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean() + else: + loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean() + + if self.args.center_rewards_coefficient is not None: + loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2) + + if return_outputs: + return loss, { + "rewards_chosen": rewards_chosen, + "rewards_rejected": rewards_rejected, + } + return loss + + def prediction_step( + self, + model: Union[PreTrainedModel, nn.Module], + inputs: dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[list[str]] = None, + ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + inputs = self._prepare_inputs(inputs) + if ignore_keys is None: + if hasattr(self.model, "config"): + ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + with torch.no_grad(): + loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True) + + if prediction_loss_only: + return (loss, None, None) + + loss = loss.detach() + logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys) + logits = nested_detach(logits) + # Stack accepted against rejected, mean over logits + # and softmax to get preferences between accepted and rejected to sum to 1 + logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T + + labels = torch.zeros(logits.shape[0]) + labels = self._prepare_inputs(labels) + + return loss, logits, labels + + def evaluate(self, *args, **kwargs): + num_print_samples = kwargs.pop("num_print_samples", 4) + self.visualize_samples(num_print_samples) + return super().evaluate(*args, **kwargs) + + def visualize_samples(self, num_print_samples: int): + """ + Visualize the reward model logits prediction + + Args: + num_print_samples (`int`, defaults to `4`): + The number of samples to print. Set to `-1` to print all samples. + """ + eval_dataloader = self.get_eval_dataloader() + table = defaultdict(list) + for _, inputs in enumerate(eval_dataloader): + _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False) + chosen_text = decode_and_strip_padding(inputs["input_ids_chosen"], self.processing_class) + rejected_text = decode_and_strip_padding(inputs["input_ids_rejected"], self.processing_class) + table["chosen_text"].extend(gather_object(chosen_text)) + table["rejected_text"].extend(gather_object(rejected_text)) + table["logits"].extend( + gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()]) + ) + if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples: + break + df = pd.DataFrame(table) + if self.accelerator.process_index == 0: + if is_rich_available(): + print_rich_table(df[:num_print_samples]) + if "wandb" in self.args.report_to: + import wandb + + if wandb.run is not None: + wandb.log({"completions": wandb.Table(dataframe=df)}) + + if "comet_ml" in self.args.report_to: + log_table_to_comet_experiment( + name="completions.csv", + table=df, + ) + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, list[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str` or `None`, *optional*, defaults to `None`): + Name of the model. + dataset_name (`str` or `None`, *optional*, defaults to `None`): + Name of the dataset used for training. + tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. + """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + tags = tags or set() + if isinstance(tags, str): + tags = {tags} + + if hasattr(self.model.config, "unsloth_version"): + tags.add("unsloth") + + tags.update(self._tag_names) + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), + trainer_name="Reward", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/rloo_config.py b/trl/trainer/rloo_config.py new file mode 100644 index 0000000000000000000000000000000000000000..e13680a1cc8231a25ea829b7b5972f48a0f0019f --- /dev/null +++ b/trl/trainer/rloo_config.py @@ -0,0 +1,114 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from dataclasses import dataclass, field + +from ..trainer.utils import OnPolicyConfig + + +@dataclass +class RLOOConfig(OnPolicyConfig): + r""" + Configuration class for the [`RLOOTrainer`]. + + This class includes only the parameters that are specific to RLOO training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] and [`OnPolicyConfig`] documentation. Note that default + values in this class may differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[: -len(".py")]`): + Name of this experiment. + reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`): + Path to the reward model. + num_ppo_epochs (`int`, *optional*, defaults to `4`): + Number of epochs to train. + whiten_rewards (`bool`, *optional*, defaults to `False`): + Whether to whiten the rewards. + kl_coef (`float`, *optional*, defaults to `0.05`): + KL coefficient. + cliprange (`float`, *optional*, defaults to `0.2`): + Clip range. + rloo_k (`int`, *optional*, defaults to `2`): + REINFORCE Leave-One-Out (RLOO) number of online samples per prompt. + normalize_reward (`bool`, *optional*, defaults to `False`): + Whether to normalize rewards. + reward_clip_range (`float`, *optional*, defaults to `10.0`): + Clip range for rewards. + normalize_advantage (`bool`, *optional*, defaults to `False`): + Whether to normalize advantages. + token_level_kl (`bool`, *optional*, defaults to `True`): + Whether to use token-level KL penalty or sequence-level KL penalty. + ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): + This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, + improving generation speed. However, disabling this option allows training models that exceed the VRAM + capacity of a single GPU, albeit at the cost of slower generation. + """ + + exp_name: str = field( + default=os.path.basename(__file__)[:-3], + metadata={"help": "Name of this experiment."}, + ) + reward_model_path: str = field( + default="EleutherAI/pythia-160m", + metadata={"help": "Path to the reward model."}, + ) + num_ppo_epochs: int = field( + default=4, + metadata={"help": "Number of epochs to train."}, + ) + whiten_rewards: bool = field( + default=False, + metadata={"help": "Whether to whiten the rewards."}, + ) + kl_coef: float = field( + default=0.05, + metadata={"help": "KL coefficient."}, + ) + cliprange: float = field( + default=0.2, + metadata={"help": "Clip range."}, + ) + rloo_k: int = field( + default=2, + metadata={"help": "REINFORCE Leave-One-Out (RLOO) number of online samples per prompt."}, + ) + normalize_reward: bool = field( + default=False, + metadata={"help": "Whether to normalize rewards"}, + ) + reward_clip_range: float = field( + default=10.0, + metadata={"help": "Clip range for rewards"}, + ) + normalize_advantage: bool = field( + default=False, + metadata={"help": "Whether to normalize advantages"}, + ) + token_level_kl: bool = field( + default=False, + metadata={"help": "Whether to use token-level KL penalty or sequence-level KL penalty"}, + ) + ds3_gather_for_generation: bool = field( + default=True, + metadata={ + "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for " + "generation, improving generation speed. However, disabling this option allows training models that " + "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation." + }, + ) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..760d633f2541b4ab09dc3947bd522a82e4c5b8b2 --- /dev/null +++ b/trl/trainer/rloo_trainer.py @@ -0,0 +1,712 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import math +import os +import textwrap +import time +from collections import defaultdict +from pathlib import Path +from typing import Callable, Optional, Union + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +from accelerate import Accelerator +from accelerate.utils import broadcast, gather_object +from datasets import Dataset +from torch.utils.data import DataLoader +from transformers import ( + BaseImageProcessor, + DataCollatorWithPadding, + FeatureExtractionMixin, + GenerationConfig, + PreTrainedTokenizerBase, + ProcessorMixin, + Trainer, + TrainerCallback, + TrainerControl, + is_wandb_available, +) +from transformers.integrations import get_reporting_integration_callbacks +from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK +from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback +from transformers.utils import is_rich_available + +from ..models.utils import unwrap_model_for_generation +from ..trainer.utils import ( + OnlineTrainerState, + batch_generation, + disable_dropout_in_model, + exact_div, + first_true_indices, + forward, + get_reward, + prepare_deepspeed, + print_rich_table, + selective_log_softmax, + truncate_response, +) +from .rloo_config import RLOOConfig +from .utils import empty_cache, generate_model_card, get_comet_experiment_url, log_table_to_comet_experiment + + +if is_wandb_available(): + import wandb + +INVALID_LOGPROB = 1.0 + + +class RLOOTrainer(Trainer): + _tag_names = ["trl", "rloo"] + + def __init__( + self, + config: RLOOConfig, + processing_class: Optional[ + Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] + ], + policy: nn.Module, + ref_policy: nn.Module, + reward_model: Union[nn.Module, Callable[[list[str]], list[float]]], + train_dataset: Dataset, + data_collator: Optional[DataCollatorWithPadding] = None, + eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, + # less commonly used + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + callbacks: Optional[list[TrainerCallback]] = None, + ) -> None: + if ref_policy is policy: + raise ValueError( + "`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the " + "same as `policy`, you must mass a copy of it, or `None` if you use peft." + ) + + self.args = config + args = config + self.processing_class = processing_class + self.policy = policy + + # Define the collator if not provided + if data_collator is None: + data_collator = DataCollatorWithPadding(self.processing_class) + + self.policy.generation_config.eos_token_id = ( + None # disable `pad_token_id` and `eos_token_id` because we just want to + ) + self.policy.generation_config.pad_token_id = None # generate tokens without truncation / padding + + self.ref_policy = ref_policy + self.reward_model = reward_model + self.train_dataset = train_dataset + self.train_dataset_len = len(train_dataset) + self.data_collator = data_collator + self.eval_dataset = eval_dataset + self.optimizer, self.lr_scheduler = optimizers + self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47 + + ######### + # calculate various batch sizes + ######### + if args.total_episodes is None: # allow the users to define episodes in terms of epochs. + args.total_episodes = int(args.num_train_epochs * self.train_dataset_len) + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + self.accelerator = accelerator + args.world_size = accelerator.num_processes + args.local_batch_size = ( + args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches + ) + args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size) + args.batch_size = int(args.local_batch_size * args.world_size) + args.mini_batch_size = exact_div( + args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`" + ) + args.local_mini_batch_size = exact_div( + args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`" + ) + args.num_total_batches = math.ceil( + args.total_episodes / args.batch_size + ) # we may train for more than `total_episodes` + time_tensor = torch.tensor(int(time.time()), device=accelerator.device) + time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes + args.run_name = f"{args.exp_name}__{args.seed}__{time_int}" + self.local_seed = args.seed + accelerator.process_index * 100003 # Prime + if args.num_sample_generations > 0: + self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations) + self.local_dataloader_batch_size = exact_div( + args.local_batch_size, args.rloo_k, "`local_batch_size` must be a multiple of rloo_k" + ) # RLOO logic: needed because RLOO repeats the same prompt args.rloo_k times + + ######### + # setup model, optimizer, and others + ######### + for module in [policy, ref_policy, reward_model]: + if isinstance(module, nn.Module): + disable_dropout_in_model(module) + if args.stop_token and args.stop_token == "eos": + args.stop_token_id = self.processing_class.eos_token_id + self.model = policy + self.create_optimizer_and_scheduler( + num_training_steps=args.num_total_batches + ) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level + + ######### + ### trainer specifics + ######### + default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) + self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks + self.callback_handler = CallbackHandler( + self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler + ) + self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) + self.control = TrainerControl() + self.state = OnlineTrainerState( + is_local_process_zero=self.is_local_process_zero(), + is_world_process_zero=self.is_world_process_zero(), + stateful_callbacks=[ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ], + ) + + self.current_flos = 0 + self.hp_search_backend = None + self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None + self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None + # Create distant repo and output directory if needed + self.hub_model_id = None + if self.args.push_to_hub: + self.init_hf_repo() + if self.args.should_save: + os.makedirs(self.args.output_dir, exist_ok=True) + self.backup_model = None + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + ######### + ### setup dataloader + ######### + self.dataloader = DataLoader( + self.train_dataset, + batch_size=self.local_dataloader_batch_size, + shuffle=True, + collate_fn=self.data_collator, + drop_last=True, # needed; otherwise the last batch will be of ragged shape + ) + # sync random states for DataLoader(shuffle=True) before `accelerator.prepare` + # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c + torch.manual_seed(args.seed) + self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader) + torch.manual_seed(self.local_seed) # reset the local seed again + + self.eval_dataloader = DataLoader( + self.eval_dataset, + batch_size=args.per_device_eval_batch_size, + collate_fn=self.data_collator, + drop_last=True, + ) # no need to shuffle eval dataset + self.eval_dataloader = accelerator.prepare(self.eval_dataloader) + + if self.is_deepspeed_enabled: + if isinstance(self.reward_model, nn.Module): + self.reward_model = prepare_deepspeed( + self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 + ) + self.ref_policy = prepare_deepspeed( + self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16 + ) + self.deepspeed = self.model + else: + self.ref_policy = self.ref_policy.to(self.accelerator.device) + if isinstance(self.reward_model, nn.Module): + self.reward_model = self.reward_model.to(self.accelerator.device) + + def get_train_dataloader(self) -> DataLoader: + return self.dataloader + + def get_eval_dataloader(self) -> DataLoader: + return self.eval_dataloader + + def train(self): + args = self.args + accelerator = self.accelerator + optimizer = self.optimizer + model = self.model + self.model_wrapped = self.model + ref_policy = self.ref_policy + reward_model = self.reward_model + processing_class = self.processing_class + dataloader = self.dataloader + device = accelerator.device + + def repeat_generator(): + while True: + yield from dataloader + + iter_dataloader = iter(repeat_generator()) + generation_config = GenerationConfig( + max_new_tokens=args.response_length, + temperature=(args.temperature + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + accelerator.print("===training policy===") + start_time = time.time() + stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps) + approxkl_stats = torch.zeros(stats_shape, device=device) + pg_clipfrac_stats = torch.zeros(stats_shape, device=device) + pg_loss_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropy_stats = torch.zeros(stats_shape, device=device) + ratio_stats = torch.zeros(stats_shape, device=device) + model.train() + + # trainer state initialization + self.state.global_step = 0 + self.state.episode = 0 + self.state.max_steps = (args.num_total_batches * args.num_mini_batches) // 2 + self.state.num_train_epochs = args.total_episodes / self.train_dataset_len + # Compute absolute values for logging, eval, and save if given as ratio + if args.logging_steps is not None: + if args.logging_steps < 1: + self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps) + else: + self.state.logging_steps = args.logging_steps + if args.eval_steps is not None: + if args.eval_steps < 1: + self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps) + else: + self.state.eval_steps = args.eval_steps + if args.save_steps is not None: + if args.save_steps < 1: + self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps) + else: + self.state.save_steps = args.save_steps + self.control = self.callback_handler.on_train_begin(args, self.state, self.control) + + for update in range(1, args.num_total_batches + 1): + self.state.episode += 1 * args.batch_size + data = next(iter_dataloader) + with torch.no_grad(): + queries = data["input_ids"].to(device) + queries = queries.repeat(args.rloo_k, 1) + context_length = queries.shape[1] + responses = [] + postprocessed_responses = [] + logprobs = [] + ref_logprobs = [] + scores = [] + sequence_lengths = [] + + # Generate responses and compute logprobs + with unwrap_model_for_generation( + self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model: + query_responses, logitss = batch_generation( + unwrapped_model, + queries, + args.local_rollout_forward_batch_size, + processing_class.pad_token_id, + generation_config, + ) + + # Process responses in batches + for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): + query = queries[i : i + args.local_rollout_forward_batch_size] + query_response = query_responses[i : i + args.local_rollout_forward_batch_size] + response = query_response[:, context_length:] + logits = logitss[i : i + args.local_rollout_forward_batch_size] + logprob = selective_log_softmax(logits, response) + del logits + empty_cache() + + ref_output = forward(ref_policy, query_response, processing_class.pad_token_id) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= args.temperature + 1e-7 + ref_logprob = selective_log_softmax(ref_logits, response) + del ref_output, ref_logits + empty_cache() + + # Response Processing 1. truncate response after the first occurrence of `stop_token_id` + postprocessed_response = response + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + args.stop_token_id, processing_class.pad_token_id, response + ) + + # Response Processing 2. run reward model on the truncated responses + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1 + + if isinstance(reward_model, nn.Module): + _, score, _ = get_reward( + reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length + ) + else: + score = torch.tensor( + reward_model( + processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True) + ), + dtype=torch.float, + ).to(device) + + # Store batch results + responses.append(response) + postprocessed_responses.append(postprocessed_response) + logprobs.append(logprob) + ref_logprobs.append(ref_logprob) + sequence_lengths.append(sequence_length) + scores.append(score) + + # Concatenate all batched results + responses = torch.cat(responses, 0) + postprocessed_responses = torch.cat(postprocessed_responses, 0) + logprobs = torch.cat(logprobs, 0) + ref_logprobs = torch.cat(ref_logprobs, 0) + sequence_lengths = torch.cat(sequence_lengths, 0) + scores = torch.cat(scores, 0) + del (logprob, ref_logprob, score) + empty_cache() + gc.collect() + + # Response Processing 3. filter response. Ensure that the sample contains stop_token_id + # responses not passing that filter will receive a low (fixed) score + # only query humans on responses that pass that filter + contain_eos_token = torch.any(postprocessed_responses == processing_class.eos_token_id, dim=-1) + if args.missing_eos_penalty is not None: + scores[~contain_eos_token] -= self.args.missing_eos_penalty + # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}") + + # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw + response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) + padding_mask = response_idxs > sequence_lengths.unsqueeze(1) + logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) + ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) + + # 4. compute rewards + # Compute KL divergence + kl = logprobs - ref_logprobs + + # Normalize rewards + if args.normalize_reward: + scores = (scores - scores.mean()) / (scores.std() + 1e-8) + scores = torch.clamp(scores, -args.reward_clip_range, args.reward_clip_range) + + # Compute total reward with KL penalty + if args.token_level_kl: + # Token-level KL penalty: apply KL penalty per token + kl_reward = -args.kl_coef * kl + + # Get the index of the last non-padded token for each sequence + eos_indices = padding_mask.size(1) - 1 - padding_mask.long().fliplr().argmax(dim=1, keepdim=True) + last_reward = torch.zeros_like(kl) + # Ensure scores has correct shape and type + scores_shaped = scores.reshape(-1, 1).to(kl.dtype) + last_reward.scatter_(dim=1, index=eos_indices, src=scores_shaped) + + # Combine KL reward and last reward + non_score_reward = kl_reward.sum(1) # Keep this for logging + reward = last_reward + kl_reward + rlhf_reward = reward.sum(1) # Sum across sequence length + else: + # Sequence-level KL penalty: sum KL across tokens first + sequence_kl = kl.sum(1) + non_score_reward = -args.kl_coef * sequence_kl + rlhf_reward = non_score_reward + scores + + # vectorized RLOO advantages implementation + rlhf_reward = rlhf_reward.reshape(args.rloo_k, -1) + baseline = (rlhf_reward.sum(0) - rlhf_reward) / (args.rloo_k - 1) + advantages = rlhf_reward - baseline + advantages = advantages.flatten() + + # Normalize advantages + if args.normalize_advantage: + advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) + + empty_cache() + + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch + for ppo_epoch_idx in range(args.num_ppo_epochs): + b_inds = np.random.permutation(args.local_batch_size) + minibatch_idx = 0 + for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size): + with accelerator.accumulate(model): + micro_batch_end = micro_batch_start + args.per_device_train_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + + # Get batch data + mb_advantage = advantages[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + + # Forward pass + output = forward(model, mb_query_responses, processing_class.pad_token_id) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.temperature + 1e-7 + + # Compute new logprobs + new_logprobs = selective_log_softmax(logits, mb_responses) + new_logprobs = torch.masked_fill( + new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB + ) + + # Compute probability ratios + new_ratio = (new_logprobs - mb_logprobs).exp() + new_logprobs = new_logprobs.sum(1) + mb_logprobs = mb_logprobs.sum(1) + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + + # PPO clipped loss + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange) + pg_loss_max = torch.max(pg_losses, pg_losses2) + pg_loss = pg_loss_max.mean() + + # Final loss + loss = pg_loss + + # Optimization step + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + + with torch.no_grad(): + pg_clipfrac = (pg_losses2 > pg_losses).float().mean() + prob_dist = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( + pg_clipfrac + ) + pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = new_ratio.mean() + gradient_accumulation_idx += 1 + minibatch_idx += 1 + + # del everything and empty cache + # fmt: off + del ( + output, logits, new_logprobs, logprobs_diff, ratio, pg_losses, + pg_losses2, pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, + mb_advantage, mb_responses, mb_query_responses, mb_logprobs, + ) + # fmt: on + empty_cache() + + # Compute metrics + with torch.no_grad(): + mean_kl = kl.sum(1).mean() + mean_entropy = (-logprobs).sum(1).mean() + mean_non_score_reward = non_score_reward.mean() + eps = int(self.state.episode / (time.time() - start_time)) + metrics = {} + metrics["eps"] = eps + metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item() + metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item() + metrics["objective/non_score_reward"] = ( + self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item() + ) + metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item() + metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item() + metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item() + metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item() + metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item() + metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item() + metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item() + metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item() + metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item() + metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item() + metrics["lr"] = self.lr_scheduler.get_last_lr()[0] + metrics["episode"] = self.state.episode + self.state.epoch = self.state.episode / (args.rloo_k * self.train_dataset_len) # used by self.log + self.log(metrics) + del kl, mean_kl, mean_entropy, scores + + self.lr_scheduler.step() + self.state.global_step += 1 + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + if self.control.should_save: + self._save_checkpoint(model, trial=None) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + empty_cache() + gc.collect() + + if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0: + self.generate_completions(sampling=True) + + # HF trainer specifics + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + if self.control.should_save: + self._save_checkpoint(model, trial=None, metrics=None) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + + def generate_completions(self, sampling: bool = False): + args = self.args + processing_class = self.processing_class + generation_config = GenerationConfig( + max_new_tokens=self.args.response_length, + temperature=(0.01 + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + table = defaultdict(list) + with unwrap_model_for_generation( + self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model: + for batch in self.eval_dataloader: + query = batch["input_ids"] + with torch.no_grad(): + context_length = query.shape[1] + query_response, _ = batch_generation( + unwrapped_model, + query, + query.shape[0], + processing_class.pad_token_id, + generation_config, + ) + response = query_response[:, context_length:] + postprocessed_response = response + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + args.stop_token_id, processing_class.pad_token_id, response + ) + table["query"].extend( + gather_object(processing_class.batch_decode(query, skip_special_tokens=True)) + ) + table["model response"].extend( + gather_object(processing_class.batch_decode(postprocessed_response)) + ) + + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + + if isinstance(self.reward_model, nn.Module): + _, score, _ = get_reward( + self.reward_model, + postprocessed_query_response, + processing_class.pad_token_id, + context_length, + ) + else: + score = torch.tensor( + self.reward_model( + processing_class.batch_decode(postprocessed_query_response, skip_special_tokens=True) + ), + dtype=torch.float, + ).to(postprocessed_query_response.device) + table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy()) + + if sampling: + break + df = pd.DataFrame(table) + + if self.accelerator.is_main_process: + if is_rich_available(): + print_rich_table(df.iloc[0 : 0 + 5]) + if "wandb" in args.report_to: + import wandb + + if wandb.run is not None: + wandb.log({"completions": wandb.Table(dataframe=df)}) + + if "comet_ml" in args.report_to: + log_table_to_comet_experiment( + name="completions.csv", + table=df, + ) + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, list[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str` or `None`, *optional*, defaults to `None`): + Name of the model. + dataset_name (`str` or `None`, *optional*, defaults to `None`): + Name of the dataset used for training. + tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. + """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + tags = tags or set() + if isinstance(tags, str): + tags = {tags} + + if hasattr(self.model.config, "unsloth_version"): + tags.add("unsloth") + + tags.update(self._tag_names) + + citation = textwrap.dedent("""\ + @inproceedings{ahmadian2024back, + title = {{Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs}}, + author = {Arash Ahmadian and Chris Cremer and Matthias Gall{\'{e}} and Marzieh Fadaee and Julia Kreutzer and Olivier Pietquin and Ahmet {\"{U}}st{\"{u}}n and Sara Hooker}, + year = 2024, + booktitle = {Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), {ACL} 2024, Bangkok, Thailand, August 11-16, 2024}, + publisher = {Association for Computational Linguistics}, + pages = {12248--12267}, + editor = {Lun{-}Wei Ku and Andre Martins and Vivek Srikumar}, + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), + trainer_name="RLOO", + trainer_citation=citation, + paper_title="Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs", + paper_id="2402.14740", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/sft_config.py b/trl/trainer/sft_config.py new file mode 100644 index 0000000000000000000000000000000000000000..89fb6189712b06dd6e70912f0e4c6c28b0019528 --- /dev/null +++ b/trl/trainer/sft_config.py @@ -0,0 +1,233 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from dataclasses import dataclass, field +from typing import Any, Optional + +from transformers import TrainingArguments + + +@dataclass +class SFTConfig(TrainingArguments): + r""" + Configuration class for the [`SFTTrainer`]. + + This class includes only the parameters that are specific to SFT training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may + differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + > Parameters that control the model + + model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): + Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` + argument of the [`SFTTrainer`] is provided as a string. + + > Parameters that control the data preprocessing + + dataset_text_field (`str`, *optional*, defaults to `"text"`): + Name of the column that contains text data in the dataset. + dataset_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): + Dictionary of optional keyword arguments for the dataset preparation. The only supported key is + `skip_prepare_dataset`. + dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): + Number of processes to use for processing the dataset. + eos_token (`str` or `None`, *optional*, defaults to `None`): + Token used to indicate the end of a turn or sequence. If `None`, it defaults to `processing_class.eos_token`. + pad_token (`int` or `None`, *optional*, defaults to `None`): + Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`, + it falls back to `processing_class.eos_token`. + max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right. + If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length. + packing (`bool`, *optional*, defaults to `False`): + Whether to group multiple sequences into fixed-length blocks to improve computational efficiency and reduce + padding. Uses `max_length` to define sequence length. + packing_strategy (`str`, *optional*, defaults to `"ffd"`): + Strategy for packing sequences. Can be either `"ffd"` (first-fit decreasing, default), or `"wrapped"`. + padding_free (`bool`, *optional*, defaults to `False`): + Whether to perform forward passes without padding by flattening all sequences in the batch into a single + continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only + supported with the `flash_attention_2` attention implementation, which can efficiently handle the flattened + batch structure. When packing is enabled with strategy `"ffd"`, padding-free is enabled, regardless of the + value of this parameter. + pad_to_multiple_of (`int` or `None`, *optional*, defaults to `None`): + If set, the sequences will be padded to a multiple of this value. + eval_packing (`bool` or `None`, *optional*, defaults to `None`): + Whether to pack the eval dataset. If `None`, uses the same value as `packing`. + + > Parameters that control the training + + completion_only_loss (`bool` or `None`, *optional*, defaults to `None`): + Whether to compute loss only on the completion part of the sequence. If set to `True`, loss is computed + only on the completion, which is supported only for [prompt-completion](#prompt-completion) datasets. If + `False`, loss is computed on the entire sequence. If `None` (default), the behavior depends on the dataset: + loss is computed on the completion for [prompt-completion](#prompt-completion) datasets, and on + the full sequence for [language modeling](#language-modeling) datasets. + activation_offloading (`bool`, *optional*, defaults to `False`): + Whether to offload the activations to the CPU. + """ + + _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"] + + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=2e-5, + metadata={"help": "The initial learning rate for AdamW."}, + ) + logging_steps: float = field( + default=10, + metadata={ + "help": ( + "Log every X updates steps. Should be an integer or a float in range `[0,1)`. " + "If smaller than 1, will be interpreted as ratio of total training steps." + ) + }, + ) + bf16: bool = field( + default=True, + metadata={ + "help": ( + "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change." + ) + }, + ) + average_tokens_across_devices: bool = field( + default=True, + metadata={ + "help": "Whether or not to average tokens across devices. If enabled, will use all_reduce to synchronize " + "num_tokens_in_batch for precise loss calculation. Reference: https://github.com/huggingface/transformers/issues/34242 " + }, + ) + + # Parameters that control the model + model_init_kwargs: Optional[dict[str, Any]] = field( + default=None, + metadata={ + "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of " + "the `SFTTrainer` is provided as a string." + }, + ) + + # Parameters that control the data preprocessing + dataset_text_field: str = field( + default="text", + metadata={"help": "Name of the column that contains text data in the dataset."}, + ) + dataset_kwargs: Optional[dict[str, Any]] = field( + default=None, + metadata={ + "help": "Dictionary of optional keyword arguments for the dataset preparation. The only supported key is " + "`skip_prepare_dataset`." + }, + ) + dataset_num_proc: Optional[int] = field( + default=None, + metadata={"help": "Number of processes to use for processing the dataset."}, + ) + eos_token: Optional[str] = field( + default=None, + metadata={ + "help": "Token used to indicate the end of a turn or sequence. If `None`, it defaults to `processing_class.eos_token`." + }, + ) + pad_token: Optional[str] = field( + default=None, + metadata={ + "help": "Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that " + "is also `None`, it falls back to `processing_class.eos_token`." + }, + ) + max_length: Optional[int] = field( + default=1024, + metadata={ + "help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from" + "the right. If `None`, no truncation is applied. When packing is enabled, this value sets the " + "sequence length." + }, + ) + packing: bool = field( + default=False, + metadata={ + "help": "Whether to group multiple sequences into fixed-length blocks to improve computational efficiency " + "and reduce padding. Uses `max_length` to define sequence length." + }, + ) + packing_strategy: str = field( + default="ffd", + metadata={ + "help": "Strategy for packing sequences. Can be either `'ffd'` (first-fit decreasing, default), or " + "`'wrapped'`." + }, + ) + padding_free: bool = field( + default=False, + metadata={ + "help": "Whether to perform forward passes without padding by flattening all sequences in the batch into " + "a single continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, " + "this is only supported with the `flash_attention_2` attention implementation, which can efficiently " + "handle the flattened batch structure. When packing is enabled with strategy `'ffd'`, padding-free is " + "enabled, regardless of the value of this parameter." + }, + ) + pad_to_multiple_of: Optional[int] = field( + default=None, + metadata={"help": "If set, the sequences will be padded to a multiple of this value."}, + ) + eval_packing: Optional[bool] = field( + default=None, + metadata={"help": "Whether to pack the eval dataset. If `None`, uses the same value as `packing`."}, + ) + + # Parameters that control the training + completion_only_loss: Optional[bool] = field( + default=None, + metadata={ + "help": ( + "Whether to compute loss only on the completion part of the sequence. If set to `True`, loss is " + "computed only on the completion, which is supported only for prompt-completion datasets. If `False`, " + "loss is computed on the entire sequence. If `None` (default), the behavior depends on the dataset: " + "loss is computed on the completion for prompt-completion datasets, and on the full sequence for " + "language modeling datasets." + ) + }, + ) + activation_offloading: bool = field( + default=False, + metadata={"help": "Whether to offload the activations to the CPU."}, + ) + + # Deprecated parameters + max_seq_length: Optional[int] = field( + default=None, + metadata={ + "help": "This parameter is deprecated and will be removed in version 0.20.0. Use `max_length` instead." + }, + ) + + def __post_init__(self): + super().__post_init__() + + if self.max_seq_length is not None: + warnings.warn( + "`max_seq_length` is deprecated and will be removed in version 0.20.0. Use `max_length` instead.", + DeprecationWarning, + ) + self.max_length = self.max_seq_length diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..fb92f3bb1fa3e8d97a649f68d0b5593e3fb2bee3 --- /dev/null +++ b/trl/trainer/sft_trainer.py @@ -0,0 +1,829 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import dataclasses +import os +import warnings +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn as nn +from accelerate import PartialState +from datasets import Dataset, IterableDataset +from packaging import version +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + BaseImageProcessor, + DataCollator, + FeatureExtractionMixin, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + Trainer, + TrainingArguments, + is_wandb_available, +) +from transformers.data.data_collator import DataCollatorMixin +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalPrediction +from transformers.utils import is_peft_available + +from ..data_utils import ( + is_conversational, + maybe_convert_to_chatml, + pack_dataset, + truncate_dataset, +) +from ..models import get_act_offloading_ctx_manager +from .sft_config import SFTConfig +from .utils import ( + ConstantLengthDataset, + generate_model_card, + get_comet_experiment_url, + pad, + peft_module_casting_to_bf16, +) + + +if is_peft_available(): + import peft + from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training + +if is_wandb_available(): + import wandb + + +@dataclass +class DataCollatorForLanguageModeling(DataCollatorMixin): + """ + Data collator used for language modeling data. Inputs are dynamically padded to the maximum length of a batch if + they are not all of the same length. + + Args: + pad_token_id (`int`): + Token ID to use for padding. + completion_only_loss (`bool`, *optional*, defaults to `True`): + When the input contains a completion mask (`completion_mask`), the labels are set to -100 for the tokens + that are no in the completion. + padding_free (`bool`, *optional*, defaults to `False`): + If set to `True`, the sequences will be flattened into a single sequence, and the position IDs will be + generated accordingly. The attention mask will be set to 1 for all tokens. + pad_to_multiple_of (`int` or `None`, *optional*, defaults to `None`): + If set, the sequences will be padded to a multiple of this value. + return_tensors (`str`, *optional*, defaults to `"pt"`): + Type of Tensor to return. Only `"pt"` is currently supported. + + Examples: + ```python + >>> from trl import DataCollatorForLanguageModeling + >>> collator = DataCollatorForLanguageModeling(pad_token_id=0) + >>> examples = [ + ... {"input_ids": [1, 2, 3]}, + ... {"input_ids": [4, 5]} + ... ] + >>> collator(examples) + {'input_ids': tensor([[ 1, 2, 3], + [ 4, 5, 0]]), + 'attention_mask': tensor([[ 1, 1, 1], + [ 1, 1, 0]]), + 'position_ids': tensor([[0, 1, 2], + [0, 1, 0]]), + 'labels': tensor([[ 1, 2, 3], + [ 4, 5, -100]])} + >>> # With completion mask + >>> examples = [ + ... {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]}, + ... {"input_ids": [4, 5], "completion_mask": [0, 1]} + ... ] + >>> collator(examples) + {'input_ids': tensor([[ 1, 2, 3], + [ 4, 5, 0]]), + 'attention_mask': tensor([[ 1, 1, 1], + [ 1, 1, 0]]), + 'position_ids': tensor([[0, 1, 2], + [0, 1, 0]]), + 'labels': tensor([[-100, 2, 3], + [-100, 5, -100]])} + + >>> # With padding_free + >>> collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True) + >>> collator(examples) + {'input_ids': tensor([[ 1, 2, 3, 4, 5]]), + 'attention_mask': tensor([[1, 1, 1, 1, 1]]), + 'position_ids': tensor([[0, 1, 2, 0, 1]]), + 'labels': tensor([[1, 2, 3, 4, 5]])} + ``` + """ + + pad_token_id: int + completion_only_loss: bool = True + padding_free: bool = False + return_position_ids: bool = True + pad_to_multiple_of: Optional[int] = None + return_tensors: str = "pt" + + def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: + # Convert to tensor + input_ids = [torch.tensor(example["input_ids"]) for example in examples] + attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids] + if self.return_position_ids: + if "position_ids" in examples[0]: + position_ids = [torch.tensor(example["position_ids"]) for example in examples] + else: + position_ids = [torch.arange(len(ids)) for ids in input_ids] + labels = [torch.tensor(example["input_ids"]) for example in examples] + if self.completion_only_loss and "completion_mask" in examples[0]: + completion_mask = [torch.tensor(example["completion_mask"]) for example in examples] + + # Pad + output = {} + if self.padding_free: + output["input_ids"] = torch.cat(input_ids, dim=0).unsqueeze(0) + output["attention_mask"] = torch.cat(attention_mask, dim=0).unsqueeze(0) + if self.return_position_ids: + output["position_ids"] = torch.cat(position_ids, dim=0).unsqueeze(0) + output["labels"] = torch.cat(labels, dim=0).unsqueeze(0) + if self.completion_only_loss and "completion_mask" in examples[0]: + completion_mask = torch.cat(completion_mask, dim=0).unsqueeze(0) + output["labels"][completion_mask == 0] = -100 + + else: + output["input_ids"] = pad( + input_ids, + padding_value=self.pad_token_id, + padding_side="right", + pad_to_multiple_of=self.pad_to_multiple_of, + ) + output["attention_mask"] = pad( + attention_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of + ) + if self.return_position_ids: + output["position_ids"] = pad( + position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of + ) + output["labels"] = pad( + labels, padding_value=-100, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of + ) + if self.completion_only_loss and "completion_mask" in examples[0]: + completion_mask = pad( + completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of + ) + output["labels"][completion_mask == 0] = -100 # mask everything that is not in the completion + + return output + + +class SFTTrainer(Trainer): + """ + Trainer for Supervised Fine-Tuning (SFT) method. + + This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods. + + Example: + + ```python + from datasets import load_dataset + from trl import SFTTrainer + + dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]") + + trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset) + trainer.train() + ``` + + Args: + model (`Union[str, PreTrainedModel]`): + Model to be trained. Can be either: + + - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or + a path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is + loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments + in `args.model_init_kwargs`. + - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. + args ([`SFTConfig`], *optional*, defaults to `None`): + Configuration for this trainer. If `None`, a default configuration is used. + data_collator (`DataCollator`, *optional*): + Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. + Will default to [`DataCollatorForLanguageModeling`]. + train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): + Dataset to use for training. SFT supports both [language modeling](#language-modeling) type and + [prompt-completion](#prompt-completion) type. The format of the samples can be either: + + - [Standard](dataset_formats#standard): Each sample contains plain text. + - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role + and content). + + The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field. + eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): + Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. + processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`): + Processing class used to process the data. If `None`, the processing class is loaded from the model's name + with [`~transformers.AutoTokenizer.from_pretrained`]. + callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`): + List of callbacks to customize the training loop. Will add those to the list of default callbacks + detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback). + + If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] + method. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`): + A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your + model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. + optimizer_cls_and_kwargs (`Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*, defaults to `None`): + A tuple containing the optimizer class and keyword arguments to use. + Overrides `optim` and `optim_args` in `args`. Incompatible with the `optimizers` argument. + + Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before initializing the Trainer. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*, defaults to `None`): + A function that preprocess the logits right before caching them at each evaluation step. Must take two + tensors, the logits and the labels, and return the logits once processed as desired. The modifications made + by this function will be reflected in the predictions received by `compute_metrics`. + + Note that the labels (second parameter) will be `None` if the dataset does not have them. + peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`): + PEFT configuration used to wrap the model. If `None`, the model is not wrapped. + formatting_func (`Optional[Callable]`): + Formatting function applied to the dataset before tokenization. Applying the formatting function explicitly + converts the dataset into a [language modeling](#language-modeling) type. + """ + + _tag_names = ["trl", "sft"] + + def __init__( + self, + model: Union[str, nn.Module, PreTrainedModel], + args: Optional[Union[SFTConfig, TrainingArguments]] = None, + data_collator: Optional[DataCollator] = None, # type: ignore + train_dataset: Optional[Union[Dataset, IterableDataset]] = None, + eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, + processing_class: Optional[ + Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] + ] = None, + compute_loss_func: Optional[Callable] = None, + compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, + callbacks: Optional[list[TrainerCallback]] = None, + optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), + optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None, + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + peft_config: Optional["PeftConfig"] = None, + formatting_func: Optional[Union[Callable[[dict], str], Callable[[dict], list[str]]]] = None, + ): + # Args + model_id = model if isinstance(model, str) else model.config._name_or_path + if args is None: + model_name = model_id.split("/")[-1] + args = SFTConfig(f"{model_name}-SFT") + elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig): + dict_args = args.to_dict() + dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token + dict_args.pop("push_to_hub_token") + args = SFTConfig(**dict_args) + + # Handle the tokenizer + if processing_class is None: + processing_class = AutoTokenizer.from_pretrained(model_id) + + if args.eos_token is not None: + eos_token = args.eos_token + eos_token_id = processing_class.convert_tokens_to_ids(eos_token) + if eos_token_id is None: + raise ValueError( + f"The specified `eos_token` ('{eos_token}') is not found in the vocabulary of the given " + f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists " + "in the vocabulary before using it as an EOS token." + ) + processing_class.eos_token_id = eos_token_id + + # Model + if args.model_init_kwargs is not None and not isinstance(model, str): + warnings.warn( + "You passed model_init_kwargs to the `SFTConfig`, but your model is already instantiated. " + "The `model_init_kwargs` will be ignored." + ) + if isinstance(model, str): + model = self._create_model_from_path(model, args) + + # PEFT configuration and model wrapping + if peft_config is not None: + model = self._prepare_peft_model(model, peft_config, args) + + # Data collator + # FFD packing requires padding-free mode; otherwise, the collator outputs padded attention masks, causing + # FlashAttention to ignore position_ids and recompute them incorrectly from the padded attention mask. + self.padding_free = args.padding_free or (args.packing and args.packing_strategy == "ffd") + if self.padding_free: + if data_collator is not None: + raise ValueError("Passing a custom data collator is not supported when using padding-free.") + if args.packing and args.packing_strategy == "wrapped": + warnings.warn( + "You are passing `padding_free=True` with the 'wrapped' packing strategy, which is not " + "recommended. Please refer to the documentation to understand why this is not recommended." + ) + if model.config._attn_implementation != "flash_attention_2": + warnings.warn( + "Padding-free training is enabled, but the attention implementation is not set to " + "'flash_attention_2'. Padding-free training flattens batches into a single sequence, and " + "'flash_attention_2' is the only known attention mechanism that reliably supports this. Using " + "other implementations may lead to unexpected behavior. To ensure compatibility, set " + "`attn_implementation='flash_attention_2'` in the model configuration, or verify that your " + "attention mechanism can handle flattened sequences." + ) + if args.per_device_train_batch_size == 1 and not args.packing: + warnings.warn( + "You are using a per_device_train_batch_size of 1 with padding-free training. Using a batch size " + "of 1 anihilate the benefits of padding-free training. Please consider increasing the batch size " + "to at least 2." + ) + + if args.completion_only_loss is None: + first_example = next(iter(train_dataset)) + self.completion_only_loss = "prompt" in first_example + else: + self.completion_only_loss = args.completion_only_loss + + if data_collator is None: + # Get the pad token: if not provided, use the one from the processing class or the eos token + # if the processing class does not have a pad token. + pad_token = args.pad_token or processing_class.pad_token or processing_class.eos_token + pad_token_id = processing_class.convert_tokens_to_ids(pad_token) + if pad_token_id is None: + raise ValueError( + f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given " + f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists " + "in the vocabulary before using it as a padding token." + ) + data_collator = DataCollatorForLanguageModeling( + pad_token_id=pad_token_id, + completion_only_loss=self.completion_only_loss, + padding_free=self.padding_free, + # Using position_ids without flash_attn hurts the training + return_position_ids=model.config._attn_implementation == "flash_attention_2", + pad_to_multiple_of=args.pad_to_multiple_of, + ) + + if ( + args.packing + and args.packing_strategy == "ffd" + and model.config._attn_implementation != "flash_attention_2" + ): + warnings.warn( + "You are using packing, but the attention implementation is not set to 'flash_attention_2'. Packing " + "flattens batches into a single sequence, and 'flash_attention_2' is the only known attention " + "mechanism that reliably supports this. Using other implementations may lead to cross-contamination " + "between batches. To avoid this, either disable packing by setting `packing=False`, or set " + "`attn_implementation='flash_attention_2'` in the model configuration." + ) + + # Dataset + preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False) + if preprocess_dataset: + if self.completion_only_loss and formatting_func: + raise ValueError( + "A formatting function was provided while `completion_only_loss=True`, which is incompatible. " + "Using a formatter converts the dataset to a language modeling type, conflicting with " + "completion-only loss. To resolve this, apply your formatting function before passing the " + "dataset, or disable `completion_only_loss` in `SFTConfig`." + ) + + train_dataset = self._prepare_dataset( + train_dataset, processing_class, args, args.packing, formatting_func, "train" + ) + if eval_dataset is not None: + packing = args.packing if args.eval_packing is None else args.eval_packing + if isinstance(eval_dataset, dict): + eval_dataset = { + key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key) + for key, dataset in eval_dataset.items() + } + else: + eval_dataset = self._prepare_dataset( + eval_dataset, processing_class, args, packing, formatting_func, "eval" + ) + + # Initialize the metrics + self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + self._total_train_tokens = 0 + + # Initialize the Trainer. Parent class will handle: + # - DeepSpeed configuration (through create_accelerator_and_postprocess) + # - FSDP setup + # - Distributed training setup + # - Optimizer and scheduler creation + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + compute_loss_func=compute_loss_func, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + optimizer_cls_and_kwargs=optimizer_cls_and_kwargs, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + # Initialize activation offloading context + if self.args.activation_offloading: + self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model) + else: + self.maybe_activation_offload_context = contextlib.nullcontext() + + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + def _create_model_from_path(self, model_path: str, args: SFTConfig) -> PreTrainedModel: + """Creates a model from a path or model identifier.""" + model_init_kwargs = args.model_init_kwargs or {} + # Handle torch dtype + torch_dtype = model_init_kwargs.get("torch_dtype") + if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None: + pass # torch_dtype is already a torch.dtype or "auto" or None + elif isinstance(torch_dtype, str): # it's a str, but not "auto" + torch_dtype = getattr(torch, torch_dtype) + model_init_kwargs["torch_dtype"] = torch_dtype + else: + raise ValueError( + "Invalid `torch_dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing " + f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}." + ) + # Disable caching if gradient checkpointing is enabled (not supported) + # if args.gradient_checkpointing: + # model_init_kwargs["use_cache"] = False + + # Create model + model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs) + return model + + def _prepare_peft_model(self, model: PreTrainedModel, peft_config: Any, args: SFTConfig) -> PreTrainedModel: + """Prepares a model for PEFT training.""" + if not is_peft_available(): + raise ImportError("To use PeftModel, you need to install the `peft` library.") + + if not isinstance(peft_config, PeftConfig): + raise ValueError( + f"Expected PeftConfig object but got {type(peft_config)}. If you want to use the PeftModel, you need " + "to pass a PeftConfig object to the SFTTrainer." + ) + + if isinstance(model, PeftModel): + return model + + # Handle quantized models (QLoRA) + is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False) + + is_sharded_qlora = False + if getattr(model, "is_loaded_in_4bit", False): + # Check if model is sharded (FSDP/DS-Zero3) + for _, param in model.named_parameters(): + if param.__class__.__name__ == "Params4bit": + is_sharded_qlora = param.data.device.type in {"cpu", "meta"} + break + + # Prepare model for kbit training if needed + if is_qlora and not is_sharded_qlora: + model = self._prepare_model_for_kbit_training(model, args) + # Disable gradient checkpointing as it's handled by prepare_model_for_kbit_training + args = dataclasses.replace(args, gradient_checkpointing=False) + elif args.gradient_checkpointing: + model = self._enable_gradient_checkpointing(model, args) + + # Create PEFT model + if ( + version.parse(peft.__version__) >= version.parse("0.12") # autocast_adapter_dtype introduced in 0.12 + and getattr(model, "is_loaded_in_4bit", False) + and is_sharded_qlora + ): + model = get_peft_model(model, peft_config, autocast_adapter_dtype=False) + else: + model = get_peft_model(model, peft_config) + + # Handle bf16 casting for 4-bit models + if args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora: + peft_module_casting_to_bf16(model) + + return model + + def _prepare_model_for_kbit_training(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel: + """Prepares a quantized model for kbit training.""" + prepare_model_kwargs = { + "use_gradient_checkpointing": args.gradient_checkpointing, + "gradient_checkpointing_kwargs": args.gradient_checkpointing_kwargs or {}, + } + + return prepare_model_for_kbit_training(model, **prepare_model_kwargs) + + def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: SFTConfig) -> PreTrainedModel: + """Enables gradient checkpointing for the model.""" + gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} + use_reentrant = ( + "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"] + ) + + if use_reentrant: + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + return model + + def _prepare_dataset( + self, + dataset: Union[Dataset, IterableDataset], + processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin], + args: SFTConfig, + packing: bool, + formatting_func: Optional[Callable[[dict], str]], + dataset_name: str, + ) -> Union[Dataset, IterableDataset]: + # Convert the dataset to an IterableDataset if it is a ConstantLengthDataset + if isinstance(dataset, ConstantLengthDataset): + return dataset + + # If the dataset is already preprocessed (tokenized), skip the processing steps. + column_names = list(next(iter(dataset)).keys()) + is_processed = "input_ids" in column_names + + # Build the kwargs for the `map` function + map_kwargs = {} + if isinstance(dataset, Dataset): # IterableDataset does not support num_proc + map_kwargs["num_proc"] = args.dataset_num_proc + + with PartialState().main_process_first(): + # Apply the formatting function if any + if formatting_func is not None and is_processed: + warnings.warn( + "You passed a dataset that is already processed (contains an `input_ids` field) together with a " + "formatting function. Therefore `formatting_func` will be ignored. Either remove the " + "`formatting_func` or pass a dataset that is not already processed.", + UserWarning, + ) + + if formatting_func is not None and not is_processed: + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Applying formatting function to {dataset_name} dataset" + + def _func(example): + return {"text": formatting_func(example)} + + try: + dataset = dataset.map(_func, batched=False, **map_kwargs) + except Exception as e: + warnings.warn( + f"Failed to apply the formatting function due to the following error: {e}. This may be " + "because the function is designed for batched input. Please update it to process one example " + "at a time (i.e., accept and return a single example). For now, we will attempt to apply the " + "function in batched mode, but note that batched formatting is deprecated and will be removed " + "in version 0.21.", + DeprecationWarning, + ) + dataset = dataset.map(_func, batched=True, **map_kwargs) + + if not is_processed: + # Convert the dataset to ChatML if needed + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Converting {dataset_name} dataset to ChatML" + column_names = next(iter(dataset)).keys() + dataset = dataset.map( + maybe_convert_to_chatml, + remove_columns="conversations" if "conversations" in column_names else None, + **map_kwargs, + ) + + # Apply the chat template if needed + first_example = next(iter(dataset)) + if not is_conversational(first_example): + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset" + + def add_eos(example, eos_token): + if "text" in example and not example["text"].endswith(eos_token): # language modeling case + example["text"] = example["text"] + eos_token + elif "completion" in example and not example["completion"].endswith(eos_token): + example["completion"] = example["completion"] + eos_token + return example + + dataset = dataset.map( + add_eos, + fn_kwargs={"eos_token": processing_class.eos_token}, + remove_columns="messages" if "messages" in column_names else None, # renamed to "text" + **map_kwargs, + ) + + # Tokenize the dataset + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" + + def tokenize(example, processing_class, dataset_text_field): + if "prompt" in example: # prompt-completion case + if is_conversational(example): + prompt_ids = processing_class.apply_chat_template(example["prompt"]) + prompt_completion_ids = processing_class.apply_chat_template( + example["prompt"] + example["completion"] + ) + else: + prompt_ids = processing_class(text=example["prompt"]).input_ids + prompt_completion_ids = processing_class( + text=example["prompt"] + example["completion"] + ).input_ids + + # Check if the tokenized prompt starts with the tokenized prompt+completion + if not prompt_completion_ids[: len(prompt_ids)] == prompt_ids: + warnings.warn( + "Mismatch between tokenized prompt and the start of tokenized prompt+completion. " + "This may be due to unexpected tokenizer behavior, whitespace issues, or special " + "token handling. Verify that the tokenizer is processing text consistently." + ) + + # Create a completion mask + completion_mask = [0] * len(prompt_ids) + [1] * (len(prompt_completion_ids) - len(prompt_ids)) + processed = {"input_ids": prompt_completion_ids, "completion_mask": completion_mask} + + else: # language modeling case + if is_conversational(example): + processed = {"input_ids": processing_class.apply_chat_template(example["messages"])} + else: + processed = {"input_ids": processing_class(text=example[dataset_text_field]).input_ids} + return processed + + dataset = dataset.map( + tokenize, + fn_kwargs={ + "processing_class": processing_class, + "dataset_text_field": args.dataset_text_field, + }, + **map_kwargs, + ) + + # Pack or truncate + if packing: + if args.max_length is None: + raise ValueError("When packing is enabled, `max_length` can't be `None`.") + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Packing {dataset_name} dataset" + dataset = dataset.select_columns("input_ids") + # Packing adds new column "position_ids" needed for document aware flash attention + dataset = pack_dataset(dataset, args.max_length, args.packing_strategy, map_kwargs) + elif args.max_length is not None: + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Truncating {dataset_name} dataset" + dataset = truncate_dataset(dataset, args.max_length, map_kwargs) + # For Liger kernel, ensure only input_ids is present + if args.use_liger_kernel: + dataset = dataset.select_columns({"input_ids", "position_ids"}.intersection(dataset.column_names)) + + return dataset + + def _set_signature_columns_if_needed(self): + # If `self.args.remove_unused_columns` is True, non-signature columns are removed. + # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids" + # and "attention_mask"). When using `train_on_completion_only` we add a "completion_mask" column to the + # dataset. So we need to override the default signature columns to include "completion_mask" as well. + if self._signature_columns is None: + self._signature_columns = ["input_ids", "attention_mask", "position_ids", "completion_mask"] + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + """ + Compute training loss and additionally compute token accuracies + """ + mode = "train" if self.model.training else "eval" + (loss, outputs) = super().compute_loss( + model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch + ) + if mode == "train": + # When using padding-free, the attention_mask is not present in the inputs, instead we have cu_seq_lens_q, + # cu_seq_lens_k, and max_length_k, max_length_q and position_ids. + if "attention_mask" in inputs: + num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item() + elif "position_ids" in inputs: + local_num_tokens = torch.tensor(inputs["position_ids"].size(1), device=inputs["position_ids"].device) + num_tokens_in_batch = self.accelerator.gather_for_metrics(local_num_tokens).sum().item() + else: + raise ValueError("Expected 'attention_mask' or 'position_ids' in inputs.") + self._total_train_tokens += num_tokens_in_batch + self._metrics[mode]["num_tokens"] = [self._total_train_tokens] + + # Compute token accuracy if we have labels and if the model is not using Liger (no logits) + if "labels" in inputs and not self.args.use_liger_kernel: + shift_logits = outputs.logits[..., :-1, :].contiguous() + shift_labels = inputs["labels"][..., 1:].contiguous() + + # Get predictions + predictions = shift_logits.argmax(dim=-1) + + # Create mask for non-padding tokens (assuming ignore_index is -100) + mask = shift_labels != -100 + + # Calculate accuracy only on non-padding tokens + correct_predictions = (predictions == shift_labels) & mask + total_tokens = mask.sum() + correct_tokens = correct_predictions.sum() + + # Gather the correct_tokens and total_tokens across all processes + correct_tokens = self.accelerator.gather_for_metrics(correct_tokens) + total_tokens = self.accelerator.gather_for_metrics(total_tokens) + + # Compute the mean token accuracy and log it + total_sum = total_tokens.sum() + accuracy = (correct_tokens.sum() / total_sum).item() if total_sum > 0 else 0.0 + self._metrics[mode]["mean_token_accuracy"].append(accuracy) + + return (loss, outputs) if return_outputs else loss + + # Override training step to add activation offloading context. + def training_step(self, *args, **kwargs): + with self.maybe_activation_offload_context: + return super().training_step(*args, **kwargs) + + def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: + mode = "train" if self.model.training else "eval" + metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics + + # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` + # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. + if mode == "eval": + metrics = {f"eval_{key}": val for key, val in metrics.items()} + + logs = {**logs, **metrics} + super().log(logs, start_time) + self._metrics[mode].clear() + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, list[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str` or `None`, *optional*, defaults to `None`): + Name of the model. + dataset_name (`str` or `None`, *optional*, defaults to `None`): + Name of the dataset used for training. + tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. + """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + tags = tags or set() + if isinstance(tags, str): + tags = {tags} + + if hasattr(self.model.config, "unsloth_version"): + tags.add("unsloth") + + tags.update(self._tag_names) + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=list(tags), + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), + trainer_name="SFT", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..84239f7b6f8fc156ab0e014f4dc997f313daaa9c --- /dev/null +++ b/trl/trainer/utils.py @@ -0,0 +1,1864 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +import importlib.resources as pkg_resources +import json +import random +import warnings +from collections import deque +from dataclasses import dataclass, field +from importlib.metadata import version +from typing import Any, Literal, Optional, Union + +import datasets +import numpy as np +import pandas as pd +import torch +import torch.nn.functional as F +import torch.utils.data +from accelerate import Accelerator, PartialState +from accelerate.state import AcceleratorState +from huggingface_hub import ModelCard, ModelCardData +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import IterableDataset +from transformers import ( + BitsAndBytesConfig, + DataCollatorForLanguageModeling, + EvalPrediction, + GenerationConfig, + PreTrainedTokenizerBase, + TrainerState, + TrainingArguments, + is_comet_available, +) +from transformers.utils import ( + is_peft_available, + is_rich_available, + is_torch_mlu_available, + is_torch_npu_available, + is_torch_xpu_available, +) + +from ..trainer.model_config import ModelConfig + + +if is_rich_available(): + from rich.console import Console + from rich.panel import Panel + from rich.table import Table + from rich.text import Text + +if is_comet_available(): + import comet_ml + +if is_peft_available(): + from peft import LoraConfig, PeftConfig + + +class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling): + """ + Data collator used for completion tasks. It ensures that all the tokens of the labels are set to an 'ignore_index' + when they do not come from the assistant. This ensure that the loss is only + calculated on the completion made by the assistant. + + Args: + response_template (`Union[str, list[int]]`): the template form that indicates the start of the response, typically something like + '### Response:\n'. It can also be passed as tokenized ids, which can be useful when using a tokenizer that encodes the response + differently if it does not have proper context. + instruction_template (`Union[str, list[int]]`): the template form that indicates the start of the human instruction, typically something like + '### Human:\n'. Useful for assistant-style conversation datasets. It can also be passed as tokenized ids. + mlm (`bool`, *optional*, defaults to `False`): Whether to use masked language modeling in the underlying + `DataCollatorForLanguageModeling` class. Note that this option currently has no effect but is present + for flexibility and backwards-compatibility. + ignore_index (`int`, *optional*, defaults to `-100`): + The index to use to ignore the initial tokens with + """ + + def __init__( + self, + response_template: Union[str, list[int]], + instruction_template: Optional[Union[str, list[int]]] = None, + *args, + mlm: bool = False, + ignore_index: int = -100, + padding_free: bool = False, + **kwargs, + ): + super().__init__(*args, mlm=mlm, **kwargs) + warnings.warn( + "This class is deprecated and will be removed in version 0.20.0. To train on completion only, please use " + "the parameter `completion_only_loss` of `SFTConfig` instead.", + DeprecationWarning, + ) + + self.instruction_template = instruction_template + if isinstance(instruction_template, str): + # The user provides a string, must tokenize + self.instruction_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False) + else: + # The user already provides the token ids + self.instruction_token_ids = instruction_template + + self.response_template = response_template + if isinstance(response_template, str): + # The user provides a string, must tokenize + self.response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False) + else: + # The user already provides the token ids + self.response_token_ids = response_template + + if not self.mlm and self.instruction_template and self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: + warnings.warn( + "The pad_token_id and eos_token_id values of this tokenizer are identical. " + "If you are planning for multi-turn training, " + "it can result in the model continuously generating questions and answers without eos token. " + "To avoid this, set the pad_token_id to a different value.", + UserWarning, + ) + + self.ignore_index = ignore_index + self.padding_free = padding_free + + def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]: + batch = super().torch_call(examples) + + if self.instruction_template is None: + for i in range(len(examples)): + response_token_ids_start_idx = None + + for idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]: + # `response_token_ids` is `'### Response:\n'`, here we are just making sure that the token IDs match + if ( + self.response_token_ids + == batch["labels"][i][idx : idx + len(self.response_token_ids)].tolist() + ): + response_token_ids_start_idx = idx + + if response_token_ids_start_idx is None: + warnings.warn( + f"Could not find response key `{self.response_template}` in the following instance: " + f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss " + "calculation. Note, if this happens often, consider increasing the `max_length`.", + UserWarning, + ) + batch["labels"][i, :] = self.ignore_index + else: + response_token_ids_end_idx = response_token_ids_start_idx + len(self.response_token_ids) + + # Make pytorch loss function ignore all tokens up through the end of the response key + batch["labels"][i, :response_token_ids_end_idx] = self.ignore_index + + else: + for i in range(len(examples)): + response_token_ids_idxs = [] + human_token_ids_idxs = [] + + for assistant_idx in np.where(batch["labels"][i] == self.response_token_ids[0])[0]: + # find the indexes of the start of a response. + if ( + self.response_token_ids + == batch["labels"][i][assistant_idx : assistant_idx + len(self.response_token_ids)].tolist() + ): + response_token_ids_idxs.append(assistant_idx + len(self.response_token_ids)) + + if len(response_token_ids_idxs) == 0: + warnings.warn( + f"Could not find response key `{self.response_template}` in the following instance: " + f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss " + "calculation. Note, if this happens often, consider increasing the `max_length`.", + UserWarning, + ) + batch["labels"][i, :] = self.ignore_index + + human_token_ids = self.instruction_token_ids + for human_idx in np.where(batch["labels"][i] == human_token_ids[0])[0]: + # find the indexes of the start of a human answer. + if human_token_ids == batch["labels"][i][human_idx : human_idx + len(human_token_ids)].tolist(): + human_token_ids_idxs.append(human_idx) + + if len(human_token_ids_idxs) == 0: + warnings.warn( + f"Could not find instruction key `{self.instruction_template}` in the following instance: " + f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss " + "calculation. Note, if this happens often, consider increasing the `max_length`.", + UserWarning, + ) + batch["labels"][i, :] = self.ignore_index + + if ( + len(human_token_ids_idxs) > 0 + and len(response_token_ids_idxs) > 0 + and human_token_ids_idxs[0] > response_token_ids_idxs[0] + ): + human_token_ids_idxs = [0] + human_token_ids_idxs + + for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)): + # Make pytorch loss function ignore all non response tokens + if idx != 0: + batch["labels"][i, start:end] = self.ignore_index + else: + batch["labels"][i, :end] = self.ignore_index + + if len(response_token_ids_idxs) < len(human_token_ids_idxs): + batch["labels"][i, human_token_ids_idxs[-1] :] = self.ignore_index + + if self.padding_free: + # remove padding, `attention_mask` and add `position_ids` + attn_mask = batch.pop("attention_mask") + batch["input_ids"] = batch["input_ids"][attn_mask.bool()].unsqueeze(0) + batch["position_ids"] = attn_mask.cumsum(1)[attn_mask.bool()].unsqueeze(0) - 1 + batch["labels"] = batch["labels"][attn_mask.bool()].unsqueeze(0) + batch["labels"][batch["position_ids"] == 0] = self.ignore_index + + # Calculate cumulative sequence lengths for queries and keys to prevent graph breaks during further computations. + flattened_position_ids = batch["position_ids"].flatten() + indices_q = torch.arange( + flattened_position_ids.size(0), device=flattened_position_ids.device, dtype=torch.int32 + ) + batch["cu_seq_lens_q"] = torch.cat( + ( + indices_q[flattened_position_ids == 0], + torch.tensor( + flattened_position_ids.size(), device=flattened_position_ids.device, dtype=torch.int32 + ), + ) + ).unsqueeze(0) + batch["cu_seq_lens_k"] = batch["cu_seq_lens_q"] + + # Determine maximum sequence lengths to prevent graph breaks during further computations. + batch["max_length_k"] = torch.tensor([flattened_position_ids.max().item() + 1]) + batch["max_length_q"] = batch["max_length_k"] + + return batch + + +@dataclass +class DataCollatorForChatML: + """ + Data collator for ChatML format datasets. + """ + + tokenizer: PreTrainedTokenizerBase + ignore_index: int = -100 + max_length: int = None + prompt_key: str = "prompt" + messages_key: str = "messages" + + def __post_init__(self): + if self.tokenizer.pad_token_id is None: + raise ValueError("The tokenizer does not have a pad token. Please set `pad_token_id` in the tokenizer.") + if self.max_length is None: + # set a sensible default + self.max_length = min(self.tokenizer.model_max_length, 1024) + + def __call__(self, examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]: + input_ids = [] + attention_mask = [] + prompts_input_ids = [] + prompt_attention_mask = [] + labels = [] + + for example in examples: + formatted_prompt = example.get(self.prompt_key, None) + if formatted_prompt is None: + prompt = example[self.messages_key][:-1] + formatted_prompt = self.tokenizer.apply_chat_template( + prompt, tokenize=False, add_generation_prompt=True + ) + + if "input_ids" not in example: + message = example[self.messages_key] + formatted_message = self.tokenizer.apply_chat_template( + message, tokenize=False, add_generation_prompt=False + ) + tokenized_message = self.tokenizer( + formatted_message, + truncation=True, + max_length=self.max_length, + padding=False, + return_tensors=None, + add_special_tokens=False, + ) + input_ids.append(tokenized_message["input_ids"]) + if "attention_mask" in example: + attention_mask.append(tokenized_message["attention_mask"]) + else: + attention_mask.append([1] * len(tokenized_message["input_ids"])) + else: + input_ids.append(example["input_ids"]) + if "attention_mask" in example: + attention_mask.append(example["attention_mask"]) + else: + attention_mask.append([1] * len(example["input_ids"])) + + tokenized_prompt = self.tokenizer( + formatted_prompt, + truncation=True, + max_length=len(input_ids[-1]), + padding=False, + return_tensors=None, + add_special_tokens=False, + ) + + prompts_input_ids.append(tokenized_prompt["input_ids"]) + prompt_attention_mask.append(tokenized_prompt["attention_mask"]) + + # Create the labels that will have all but the completion tokens of the example["input_ids"] set to ignore_index + label = [self.ignore_index] * len(input_ids[-1]) + completion_start_idx = len(tokenized_prompt["input_ids"]) + label[completion_start_idx:] = input_ids[-1][completion_start_idx:] + labels.append(label) + + # convert to list of tensors and pad + input_ids = [torch.tensor(ids, dtype=torch.long) for ids in input_ids] + attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in attention_mask] + labels = [torch.tensor(label, dtype=torch.long) for label in labels] + input_ids = pad(input_ids, padding_side="left", padding_value=self.tokenizer.pad_token_id) + attention_mask = pad(attention_mask, padding_side="left", padding_value=0) + labels = pad(labels, padding_side="left", padding_value=self.ignore_index) + + prompts_input_ids = [torch.tensor(ids, dtype=torch.long) for ids in prompts_input_ids] + prompt_attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in prompt_attention_mask] + prompts_input_ids = pad(prompts_input_ids, padding_side="left", padding_value=self.tokenizer.pad_token_id) + prompt_attention_mask = pad(prompt_attention_mask, padding_side="left", padding_value=0) + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "labels": labels, + "prompts": prompts_input_ids, + "prompt_attention_mask": prompt_attention_mask, + } + + +@dataclass +class RewardDataCollatorWithPadding: + r""" + Reward DataCollator class that pads the inputs to the maximum length of the batch. + + Args: + tokenizer (`PreTrainedTokenizerBase`): + The tokenizer used for encoding the data. + padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`): + padding_strategy to pass to the tokenizer. + pad_to_multiple_of (`int` or `None`, `optional`, defaults to `None`): + If set will pad the sequence to a multiple of the provided value. + return_tensors (`str`, `optional`, defaults to `"pt"`): + The tensor type to use. + """ + + tokenizer: PreTrainedTokenizerBase + padding: Union[bool, str] = True + pad_to_multiple_of: Optional[int] = None + return_tensors: str = "pt" + + def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: + features_chosen = [] + features_rejected = [] + margin = [] + # check if we have a margin. If we do, we need to batch it as well + has_margin = "margin" in features[0] + for feature in features: + # check if the keys are named as expected + if ( + "input_ids_chosen" not in feature + or "input_ids_rejected" not in feature + or "attention_mask_chosen" not in feature + or "attention_mask_rejected" not in feature + ): + raise ValueError( + "The features should include `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`" + ) + + features_chosen.append( + { + "input_ids": feature["input_ids_chosen"], + "attention_mask": feature["attention_mask_chosen"], + } + ) + features_rejected.append( + { + "input_ids": feature["input_ids_rejected"], + "attention_mask": feature["attention_mask_rejected"], + } + ) + if has_margin: + margin.append(feature["margin"]) + batch_chosen = self.tokenizer.pad( + features_chosen, + padding=self.padding, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors=self.return_tensors, + ) + batch_rejected = self.tokenizer.pad( + features_rejected, + padding=self.padding, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors=self.return_tensors, + ) + batch = { + "input_ids_chosen": batch_chosen["input_ids"], + "attention_mask_chosen": batch_chosen["attention_mask"], + "input_ids_rejected": batch_rejected["input_ids"], + "attention_mask_rejected": batch_rejected["attention_mask"], + "return_loss": True, + } + if has_margin: + margin = torch.tensor(margin, dtype=torch.float) + batch["margin"] = margin + return batch + + +def pad( + tensors: list[torch.Tensor], + padding_value: int = 0, + padding_side: str = "right", + pad_to_multiple_of: Optional[int] = None, +) -> torch.Tensor: + """ + Pads a list of tensors to the same shape along the first dimension. + + Args: + tensors (`list[torch.Tensor]`): + List of input tensors to pad. + padding_value (`int`): + Value to use for padding. Default is 0. + padding_side (`str`): + Side on which to add padding. Must be 'left' or 'right'. Default is 'right'. + pad_to_multiple_of (`int`, *optional*, defaults to `None`): + If set will pad the sequence to a multiple of the provided value. + + Returns: + `torch.Tensor`: + A single tensor containing the padded tensors. + + Examples: + >>> import torch + >>> pad([torch.tensor([1, 2, 3]), torch.tensor([4, 5])]) + tensor([[1, 2, 3], + [4, 5, 0]]) + >>> pad([torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6]])]) + tensor([[[1, 2], + [3, 4]], + + [[5, 6], + [0, 0]]]) + """ + # Determine the maximum shape for each dimension + output_shape = np.max([t.shape for t in tensors], 0).tolist() + + # Apply pad_to_multiple_of to the first (sequence) dimension + if pad_to_multiple_of is not None: + remainder = output_shape[0] % pad_to_multiple_of + if remainder != 0: + output_shape[0] += pad_to_multiple_of - remainder + + # Create an output tensor filled with the padding value + output = torch.full((len(tensors), *output_shape), padding_value, dtype=tensors[0].dtype, device=tensors[0].device) + + for i, t in enumerate(tensors): + if padding_side == "left": + seq_start = output_shape[0] - t.shape[0] + elif padding_side == "right": + seq_start = 0 + else: + raise ValueError("padding_side must be 'left' or 'right'") + + # Define the slices + seq_slice = slice(seq_start, seq_start + t.shape[0]) + slices = (seq_slice,) + tuple(slice(0, s) for s in t.shape[1:]) + output[i][slices] = t + + return output + + +@dataclass +class DPODataCollatorWithPadding: + r""" + DPO DataCollator class that pads the tokenized inputs to the maximum length of the batch. + + Args: + pad_token_id (`int` defaults to 0): + The tokenizer's pad_token_id. + label_pad_token_id (`int`, defaults to -100): + The label used for masking. + is_encoder_decoder (`bool` or `None`, `optional`, defaults to `None`): + Whether you model has an encoder_decoder architecture. + """ + + pad_token_id: int = 0 + label_pad_token_id: int = -100 + is_encoder_decoder: Optional[bool] = False + + def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: + # first, pad everything to the same length + padded_batch = {} + for k in features[0].keys(): + if k.endswith(("_input_ids", "_attention_mask", "_labels", "_pixel_values")): + if self.is_encoder_decoder: + to_pad = [torch.LongTensor(ex[k]) for ex in features] + + if (k.startswith("prompt")) and (k.endswith("input_ids")): + if self.pad_token_id is None: + raise ValueError( + "Padding is enabled, but the tokenizer is not configured with a padding token." + " Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)" + " before calling the trainer." + ) + padding_value = self.pad_token_id + elif k.endswith("_attention_mask"): + padding_value = 0 + elif k.startswith(("chosen", "rejected", "completion")) or ("decoder" in k): + padding_value = self.label_pad_token_id + else: + raise ValueError(f"Unexpected key in batch '{k}'") + padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value) + else: + # Set padding value based on the key + if k.endswith("_input_ids"): + if self.pad_token_id is None: + raise ValueError( + "Padding is enabled, but the tokenizer is not configured with a padding token." + " Explicitly set `tokenizer.pad_token` (e.g. `tokenizer.pad_token = tokenizer.eos_token`)" + " before calling the trainer." + ) + padding_value = self.pad_token_id + elif k.endswith("_labels"): + padding_value = self.label_pad_token_id + elif k.endswith("_attention_mask"): + padding_value = 0 + elif k.endswith("_pixel_values"): + padding_value = 0 # TODO: check if this is correct + else: + raise ValueError(f"Unexpected key in batch '{k}'") + + # Set padding side based on the key + if k in ["prompt_input_ids", "prompt_attention_mask"]: + padding_side = "left" + else: + padding_side = "right" + + # Set the dtype + if k.endswith("_pixel_values"): + dtype = torch.float32 # will be downcasted if necessary by the Trainer + else: + dtype = torch.int64 + + # Convert to tensor and pad + to_pad = [torch.tensor(ex[k], dtype=dtype) for ex in features] + padded_batch[k] = pad(to_pad, padding_value=padding_value, padding_side=padding_side) + elif k.endswith("_logps"): + # the cached reference model logprobs + padded_batch[k] = torch.tensor([ex[k] for ex in features]) + else: + padded_batch[k] = [ex[k] for ex in features] + + return padded_batch + + +class ConstantLengthDataset(IterableDataset): + """ + Iterable dataset that returns constant length chunks of tokens from stream of text files. + The dataset also formats the text before tokenization with a specific format that is provided + by the user. + + Args: + tokenizer (`transformers.PreTrainedTokenizer`): + The processor used for processing the data. + dataset (`dataset.Dataset`): + Dataset with text files. + dataset_text_field (`str` or `None`, *optional*, defaults to `None`): + Name of the field in the dataset that contains the text. Only one of `dataset_text_field` and + `formatting_func` should be provided. + formatting_func (`Callable`, *optional*): + Function that formats the text before tokenization. Usually it is recommended to follow a certain + pattern such as `"### Question: {question} ### Answer: {answer}"`. Only one of `dataset_text_field` and + `formatting_func` should be provided. + infinite (`bool`, *optional*, defaults to `False`): + If True the iterator is reset after dataset reaches end else stops. + seq_length (`int`, *optional*, defaults to `1024`): + Length of token sequences to return. + num_of_sequences (`int`, *optional*, defaults to `1024`): + Number of token sequences to keep in buffer. + chars_per_token (`int`, *optional*, defaults to `3.6`): + Number of characters per token used to estimate number of tokens in text buffer. + eos_token_id (`int`, *optional*, defaults to `0`): + Id of the end of sequence token if the passed tokenizer does not have an EOS token. + shuffle (`bool`, *optional*, defaults to `True`) + Shuffle the examples before they are returned + append_concat_token (`bool`, *optional*, defaults to `True`) + If true, appends `eos_token_id` at the end of each sample being packed. + add_special_tokens (`bool`, *optional*, defaults to `True`) + If true, tokenizers adds special tokens to each sample being packed. + """ + + def __init__( + self, + tokenizer, + dataset, + dataset_text_field=None, + formatting_func=None, + infinite=False, + seq_length=1024, + num_of_sequences=1024, + chars_per_token=3.6, + eos_token_id=0, + shuffle=True, + append_concat_token=True, + add_special_tokens=True, + ): + warnings.warn( + "This class is deprecated and will be removed in version 0.20.0. To use packing, use the argument " + "`packing` of `SFTConfig` instead.", + DeprecationWarning, + ) + self.tokenizer = tokenizer + self.concat_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id else eos_token_id + self.dataset = dataset + self.seq_length = seq_length + self.infinite = infinite + self.current_size = 0 + self.max_buffer_size = seq_length * chars_per_token * num_of_sequences + self.shuffle = shuffle + self.append_concat_token = append_concat_token + self.add_special_tokens = add_special_tokens + + if dataset_text_field is not None and formatting_func is not None: + warnings.warn( + "Only one of `dataset_text_field` and `formatting_func` should be provided. " + "Ignoring `dataset_text_field` and using `formatting_func`.", + UserWarning, + ) + + if formatting_func is not None: + self.formatting_func = formatting_func + elif dataset_text_field is not None: + self.formatting_func = lambda x: x[dataset_text_field] + else: # neither is provided + raise ValueError("Either `dataset_text_field` or `formatting_func` should be provided.") + + self.pretokenized = False + column_names = ( + dataset.column_names if isinstance(dataset, (datasets.Dataset, datasets.IterableDataset)) else None + ) + if column_names is not None and "input_ids" in column_names: + self.pretokenized = True + # since the dataset is tokenized, the unit of buffer size should be tokens + self.max_buffer_size = seq_length * num_of_sequences + + def __len__(self): + return len(self.dataset) + + def __iter__(self): + iterator = iter(self.dataset) + more_examples = True + while more_examples: + buffer, buffer_len = [], 0 + while True: + if buffer_len >= self.max_buffer_size: + break + try: + buffer.append(self.formatting_func(next(iterator))) + buffer_len += len(buffer[-1]) + except StopIteration: + if self.infinite: + iterator = iter(self.dataset) + else: + more_examples = False + break + if self.shuffle: + random.shuffle(buffer) + if self.pretokenized: + tokenized_inputs = buffer + else: + tokenized_inputs = self.tokenizer( + buffer, add_special_tokens=self.add_special_tokens, truncation=False + )["input_ids"] + all_token_ids = [] + for tokenized_input in tokenized_inputs: + if self.append_concat_token: + tokenized_input = tokenized_input + [self.concat_token_id] + all_token_ids.extend(tokenized_input) + examples = [] + for i in range(0, len(all_token_ids), self.seq_length): + input_ids = all_token_ids[i : i + self.seq_length] + if len(input_ids) == self.seq_length: + examples.append(input_ids) + if self.shuffle: + # Shuffle again, otherwise split examples occur in consecutive tensors. + random.shuffle(examples) + for example in examples: + self.current_size += 1 + yield { + "input_ids": torch.LongTensor(example), + "labels": torch.LongTensor(example), + } + + +@dataclass +class RunningMoments: + """ + Calculates the running mean and standard deviation of a data stream. Reference: + https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L75 + """ + + accelerator: Accelerator + mean: float = 0 + std: float = 1 + var: float = 1 + count: float = 1e-24 + + @torch.no_grad() + def update(self, xs: torch.Tensor) -> tuple[float, float]: + """ + Updates running moments from batch's moments computed across ranks + """ + if self.accelerator.use_distributed: + xs_mean, xs_var, xs_count = get_global_statistics(self.accelerator, xs) + else: + xs_count = xs.numel() + xs_var, xs_mean = torch.var_mean(xs, unbiased=False) + xs_mean, xs_var = xs_mean.float(), xs_var.float() + + delta = xs_mean - self.mean + tot_count = self.count + xs_count + + new_sum = xs_var * xs_count + # correct old_sum deviation accounting for the new mean + old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count + tot_sum = old_sum + new_sum + + self.mean += (delta * xs_count / tot_count).item() + new_var = tot_sum / tot_count + self.std = (new_var * tot_count / (tot_count - 1)).float().sqrt().item() + self.var = new_var.item() + self.count = tot_count + + return xs_mean.item(), (xs_var * xs_count / (xs_count - 1)).float().sqrt().item() + + def save_to_json(self, json_path: str): + """Save the content of this instance in JSON format inside `json_path`.""" + # save everything except accelerator + if self.accelerator.is_main_process: + save_dict = dataclasses.asdict(self, dict_factory=lambda x: {k: v for (k, v) in x if k != "accelerator"}) + json_string = json.dumps(save_dict, indent=2, sort_keys=True) + "\n" + with open(json_path, "w", encoding="utf-8") as f: + f.write(json_string) + + @classmethod + def load_from_json(cls, accelerator: Accelerator, json_path: str): + """Create an instance from the content of `json_path`.""" + # load everything except accelerator + with open(json_path, encoding="utf-8") as f: + text = f.read() + return cls(accelerator=accelerator, **json.loads(text)) + + +@torch.no_grad() +def get_global_statistics( + accelerator, xs: torch.Tensor, mask=None, device="cpu" +) -> tuple[torch.Tensor, torch.Tensor, int]: + """ + Computes element-wise mean and variance of the tensor across processes. Reference: + https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L57C1-L73C75 + """ + xs = xs.to(accelerator.device) + sum_and_count = torch.tensor([xs.sum(), (xs.numel() if mask is None else mask.sum())], device=xs.device) + sum_and_count = accelerator.reduce(sum_and_count) + global_sum, count = sum_and_count + global_mean = global_sum / count + + sum_var = torch.sum(((xs - global_mean) ** 2).mul(1 if mask is None else mask)) + sum_var = accelerator.reduce(sum_var) + global_var = sum_var / count + + return global_mean.to(device), global_var.to(device), count.item() + + +def compute_accuracy(eval_pred: EvalPrediction) -> dict[str, float]: + predictions, labels = eval_pred + if predictions.ndim == 3: + # Token classification task. Shapes are (batch_size, seq_len, num_labels) and (batch_size, seq_len) + # Used to compute the accuracy in the prm_trainer. + predictions = np.argmax(predictions, axis=2) + + # Flatten the predictions and labels to remove the ignored tokens. + predictions = np.array( + [p for prediction, label in zip(predictions, labels) for (p, lbl) in zip(prediction, label) if lbl != -100] + ) + labels = np.array([lbl for label in labels for lbl in label if lbl != -100]) + + else: + # Here, predictions is rewards_chosen and rewards_rejected. Shapes are (batch_size, 2) and (batch_size,) + # We want to see how much of the time rewards_chosen > rewards_rejected. + equal_mask = predictions[:, 0] == predictions[:, 1] + equal_predictions_count = int(equal_mask.sum()) + + if equal_predictions_count > 0: + warnings.warn( + f"There are {equal_predictions_count} out of {len(predictions[:, 0])} instances where the predictions " + "for both options are equal. These instances are ignored in the accuracy computation.", + UserWarning, + ) + + # Filter out equal predictions + predictions = predictions[~equal_mask] + labels = labels[~equal_mask] + + # Use the remaining predictions for accuracy calculation + predictions = np.argmax(predictions, axis=1) + + accuracy = np.array(predictions == labels, dtype=float).mean().item() + return {"accuracy": accuracy} + + +def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor: + if tensor.size(dim) >= length: + return tensor + else: + pad_size = list(tensor.shape) + pad_size[dim] = length - tensor.size(dim) + return torch.cat( + [ + tensor, + pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device), + ], + dim=dim, + ) + + +def disable_dropout_in_model(model: torch.nn.Module) -> None: + for module in model.modules(): + if isinstance(module, torch.nn.Dropout): + module.p = 0 + + +def exact_div(a, b, custom_error_message=""): + q = a // b + if a != q * b: + raise ValueError(f"{custom_error_message}, inexact division: {a} / {b} = {a / b}") + return q + + +# copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/stat_tracking.py#L5 +class PerPromptStatTracker: + r""" + Class for tracking statistics per prompt. Mainly used to calculate advantage for the DPPO algorithm + + Args: + buffer_size (`int`): + Size of the buffer to keep for each prompt. + min_count (`int`): + Minimum number of samples to keep in the buffer before calculating the mean and std. + """ + + def __init__(self, buffer_size, min_count): + self.buffer_size = buffer_size + self.min_count = min_count + self.stats = {} + + def update(self, prompts, rewards): + prompts = np.array(prompts) + rewards = np.array(rewards) + unique = np.unique(prompts) + advantages = np.empty_like(rewards) + for prompt in unique: + prompt_rewards = rewards[prompts == prompt] + if prompt not in self.stats: + self.stats[prompt] = deque(maxlen=self.buffer_size) + self.stats[prompt].extend(prompt_rewards) + + if len(self.stats[prompt]) < self.min_count: + mean = np.mean(rewards) + std = np.std(rewards) + 1e-6 + else: + mean = np.mean(self.stats[prompt]) + std = np.std(self.stats[prompt]) + 1e-6 + advantages[prompts == prompt] = (prompt_rewards - mean) / std + + return advantages + + def get_stats(self): + return {k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)} for k, v in self.stats.items()} + + +def peft_module_casting_to_bf16(model): + for name, module in model.named_modules(): + if isinstance(module, torch.nn.LayerNorm) or "norm" in name: + module = module.to(torch.float32) + elif any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]): + if hasattr(module, "weight"): + if module.weight.dtype == torch.float32: + module = module.to(torch.bfloat16) + + +def get_quantization_config(model_args: ModelConfig) -> Optional[BitsAndBytesConfig]: + if model_args.load_in_4bit: + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=model_args.torch_dtype, # For consistency with model weights, we use the same value as `torch_dtype` + bnb_4bit_quant_type=model_args.bnb_4bit_quant_type, + bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant, + bnb_4bit_quant_storage=model_args.torch_dtype, + ) + elif model_args.load_in_8bit: + quantization_config = BitsAndBytesConfig( + load_in_8bit=True, + ) + else: + quantization_config = None + + return quantization_config + + +def get_kbit_device_map() -> Optional[dict[str, int]]: + if torch.cuda.is_available() or is_torch_xpu_available(): + return {"": PartialState().local_process_index} + else: + return None + + +def get_peft_config(model_args: ModelConfig) -> "Optional[PeftConfig]": + if model_args.use_peft is False: + return None + + if not is_peft_available(): + raise ValueError( + "You need to have PEFT library installed in your environment, make sure to install `peft`. " + "Make sure to run `pip install -U peft`." + ) + + peft_config = LoraConfig( + task_type=model_args.lora_task_type, + r=model_args.lora_r, + target_modules=model_args.lora_target_modules, + lora_alpha=model_args.lora_alpha, + lora_dropout=model_args.lora_dropout, + bias="none", + use_rslora=model_args.use_rslora, + use_dora=model_args.use_dora, + modules_to_save=model_args.lora_modules_to_save, + ) + + return peft_config + + +def get_exp_cap(value, decimal=4): + """ + Get the exponent cap of a value. This is used to cap the exponent of a value to avoid overflow. + The formula is : log(value.dtype.max) + E.g. + For float32 data type, the maximum exponent value is 88.7228 to 4 decimal points. + ``` + + Args: + value (`torch.Tensor`): + The input tensor to obtain the data type + decimal (`int`): + The number of decimal points of the output exponent cap. + eg: direct calling exp(log(torch.float32.max)) will result in inf + so we cap the exponent to 88.7228 to avoid overflow. + """ + vdtype_max = torch.zeros([1]).to(value.dtype) + torch.finfo(value.dtype).max + vdtype_log_max = torch.log(vdtype_max).to(value.device) + return torch.floor(vdtype_log_max * 10**decimal) / 10**decimal if decimal > 0 else vdtype_log_max + + +def cap_exp(value, cap=-1): + # Cap the exponent value below the upper-bound to avoid overflow, before calling torch.exp + cap = get_exp_cap(value) if cap < 0 else cap + return torch.exp(torch.clamp(value, max=cap)) + + +def print_rich_table(df: pd.DataFrame) -> None: + if not is_rich_available(): + raise ImportError( + "The function `print_rich_table` requires the `rich` library. Please install it with `pip install rich`." + ) + console = Console() + table = Table(show_lines=True) + for column in df.columns: + table.add_column(column) + for _, row in df.iterrows(): + table.add_row(*row.astype(str).tolist()) + console.print(table) + + +SIMPLE_SFT_CHAT_TEMPLATE = "{% for message in messages %}{{' ' + message['content']}}{% endfor %}{{ eos_token }}" +# SIMPLE_SFT_CHAT_TEMPLATE simply ends things with an EOS token, this helps the SFT model learn to end the completions with EOS tokens + +SIMPLE_CHAT_TEMPLATE = "{% for message in messages %}{{message['role'].capitalize() + ': ' + message['content'] + '\n\n'}}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}" + + +@dataclass +class OnlineTrainerState(TrainerState): + episode: int = 0 + + +@dataclass +class OnPolicyConfig(TrainingArguments): + r""" + Base configuration class for on-policy trainers. + + This class includes only the parameters that are specific to some on-policy training. For a full list of training + arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this + class may differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + run_name (`str` or `None`, *optional*, defaults to `None`): + Name of the run. + dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): + Number of processes to use for processing the dataset. + num_mini_batches (`int`, *optional*, defaults to `1`): + Number of minibatches to split a batch into. + total_episodes (`int` or `None`, *optional*, defaults to `None`): + Total number of episodes in the dataset. + local_rollout_forward_batch_size (`int`, *optional*, defaults to `64`): + Per rank no grad forward pass in the rollout phase. + num_sample_generations (`int`, *optional*, defaults to `10`): + Number of debugging samples generations (i.e., `generate_completions` calls) throughout training. + response_length (`int`, *optional*, defaults to `53`): + Length of the response. + stop_token (`str` or `None`, *optional*, defaults to `None`): + Specifies the stop token to use for text generation. This parameter is mutually exclusive with + `stop_token_id`. + + - `None`: No stop token is applied, unless `stop_token_id` is specified. + - `'eos'`: Uses the tokenizer's `eos_token`. + + stop_token_id (`int` or `None`, *optional*, defaults to `None`): + Specifies the ID of the stop token to use for text generation. If `None`, no stop token ID is applied, + unless `stop_token` is specified. This parameter is mutually exclusive with `stop_token`. + temperature (`float`, *optional*, defaults to `0.7`): + Sampling temperature. + missing_eos_penalty (`float` or `None`, *optional*, defaults to `None`): + Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage + to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive + value. + sft_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`): + Path to the SFT model. + world_size (`int` or `None`, *optional*, defaults to `None`): + Number of processes (GPUs) to use for the training. + num_total_batches (`int` or `None`, *optional*, defaults to `None`): + Number of total batches to train. + micro_batch_size (`int` or `None`, *optional*, defaults to `None`): + Micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`). + local_batch_size (`int` or `None`, *optional*, defaults to `None`): + Batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`). + batch_size (`int` or `None`, *optional*, defaults to `None`): + Batch size across devices (HF's `per_device_train_batch_size` * `world_size` * `gradient_accumulation_steps`). + local_mini_batch_size (`int` or `None`, *optional*, defaults to `None`): + Mini batch size per GPU. + mini_batch_size (`int` or `None`, *optional*, defaults to `None`): + Mini batch size across GPUs. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether to push the model to the Hub after training. + """ + + # Parameters whose default values are overridden from TrainingArguments + logging_steps: float = field( + default=10, + metadata={ + "help": ( + "Log every X updates steps. Should be an integer or a float in range `[0,1)`. " + "If smaller than 1, will be interpreted as ratio of total training steps." + ) + }, + ) + bf16: bool = field( + default=True, + metadata={ + "help": ( + "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA " + "architecture or using CPU (use_cpu) or Ascend NPU. This is an experimental API and it may change." + ) + }, + ) + + run_name: Optional[str] = field( + default=None, + metadata={"help": "Name of the run."}, + ) + dataset_num_proc: Optional[int] = field( + default=None, + metadata={"help": "Number of processes to use for processing the dataset."}, + ) + num_mini_batches: int = field( + default=1, + metadata={"help": "Number of minibatches to split a batch into."}, + ) + total_episodes: Optional[int] = field( + default=None, + metadata={"help": "Total number of episodes in the dataset."}, + ) + local_rollout_forward_batch_size: int = field( + default=64, + metadata={"help": "Per rank no grad forward pass in the rollout phase."}, + ) + num_sample_generations: int = field( + default=10, + metadata={ + "help": "Number of debugging samples generations (i.e., `generate_completions` calls) throughout training." + }, + ) + response_length: int = field( + default=53, + metadata={"help": "Length of the response."}, + ) + stop_token: Optional[Literal["eos"]] = field( + default=None, + metadata={ + "help": "Specifies the stop token to use for text generation. This parameter is mutually exclusive with " + "`stop_token_id`." + }, + ) + stop_token_id: Optional[int] = field( + default=None, + metadata={ + "help": "Specifies the ID of the stop token to use for text generation. If `None`, no stop token ID is " + "applied, unless `stop_token` is specified. This parameter is mutually exclusive with `stop_token`." + }, + ) + temperature: float = field( + default=0.7, + metadata={"help": "Sampling temperature."}, + ) + missing_eos_penalty: Optional[float] = field( + default=None, + metadata={ + "help": "Penalty applied to the score when the model fails to generate an EOS token. This is useful to " + "encourage to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be " + "a positive value." + }, + ) + sft_model_path: str = field( + default="EleutherAI/pythia-160m", + metadata={"help": "Path to the SFT model."}, + ) + world_size: Optional[int] = field( + default=None, + metadata={"help": "Number of processes (GPUs) to use for the training."}, + ) + num_total_batches: Optional[int] = field( + default=None, + metadata={"help": "Number of total batches to train."}, + ) + micro_batch_size: Optional[int] = field( + default=None, + metadata={"help": "Micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`)."}, + ) + local_batch_size: Optional[int] = field( + default=None, + metadata={"help": "Batch size per GPU (HF's `per_device_train_batch_size` * `gradient_accumulation_steps`)."}, + ) + batch_size: Optional[int] = field( + default=None, + metadata={ + "help": "Batch size across devices (HF's `per_device_train_batch_size` * `world_size` * " + "`gradient_accumulation_steps`)." + }, + ) + local_mini_batch_size: Optional[int] = field( + default=None, + metadata={"help": "Mini batch size per GPU."}, + ) + mini_batch_size: Optional[int] = field( + default=None, + metadata={"help": "Mini batch size across GPUs."}, + ) + push_to_hub: bool = field( + default=False, + metadata={"help": "Whether to push the model to the Hub after training."}, + ) + + +def first_true_indices(bools: torch.Tensor, dtype=torch.long): + """ + Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving + the position of the first True in each "row". + + Returns the length of the rows (bools.size(-1)) if no element is True in a given row. + + Args: + bools (`torch.Tensor`): + An N-dimensional boolean tensor. + dtype (`torch.dtype`, optional): + The desired data type of the output tensor. Defaults to `torch.long`. + + Returns: + `torch.Tensor`: + An (N-1)-dimensional tensor of integers indicating the position of the first True + in each row. If no True value is found in a row, returns the length of the row. + """ + row_len = bools.size(-1) + zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device) + return torch.min(zero_or_index, dim=-1).values + + +def get_reward( + model: torch.nn.Module, query_responses: torch.Tensor, pad_token_id: int, context_length: int +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes the reward logits and the rewards for a given model and query responses. + + Args: + model (`torch.nn.Module`): + The model used to compute the reward logits. + query_responses (`torch.Tensor`): + The tensor containing the query responses. + pad_token_id (`int`): + The token ID representing the pad token. + context_length (`int`): + The length of the context in the query responses. + + Returns: + tuple: + - `reward_logits` (`torch.Tensor`): + The logits for the reward model. + - `final_rewards` (`torch.Tensor`): + The final rewards for each query response. + - `sequence_lengths` (`torch.Tensor`): + The lengths of the sequences in the query responses. + """ + attention_mask = query_responses != pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() # exclusive cumsum + lm_backbone = getattr(model, model.base_model_prefix) + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + output = lm_backbone( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + use_cache=False, # otherwise mistral-based RM would error out + ) + reward_logits = model.score(output.hidden_states[-1]) + sequence_lengths = first_true_indices(query_responses[:, context_length:] == pad_token_id) - 1 + context_length + # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454 + return ( + reward_logits, + reward_logits[ + torch.arange(reward_logits.size(0), device=reward_logits.device), + sequence_lengths, + ].squeeze(-1), + sequence_lengths, + ) + + +def forward( + model: torch.nn.Module, + query_responses: torch.Tensor, + pad_token_id: int, +) -> torch.nn.Module: + """ + Performs a forward pass through the model with the given query responses and pad token ID. + + Args: + model (`torch.nn.Module`): + The model to perform the forward pass. + query_responses (`torch.Tensor`): + The tensor containing the query responses. + pad_token_id (`int`): + The token ID representing the pad token. + + Returns: + `torch.nn.Module`: + The output of the model, including hidden states. + """ + attention_mask = query_responses != pad_token_id + position_ids = attention_mask.cumsum(1) - attention_mask.long() + input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) + return model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + return_dict=True, + output_hidden_states=True, + ) + + +def prepare_deepspeed( + model: torch.nn.Module, per_device_train_batch_size: int, fp16: bool = False, bf16: bool = False +): + """ + Prepares the model for training with DeepSpeed (both for stage 2 and 3), configuring the appropriate settings based on the model and + batch size. + + Args: + model (`torch.nn.Module`): + The model to be prepared for DeepSpeed training. + per_device_train_batch_size (`int`): + The training batch size per device. + + Returns: + `torch.nn.Module`: + The model initialized and configured with DeepSpeed for training. + """ + import deepspeed + + deepspeed_plugin = AcceleratorState().deepspeed_plugin + config_kwargs = deepspeed_plugin.deepspeed_config + if config_kwargs["zero_optimization"]["stage"] != 3: + config_kwargs["train_micro_batch_size_per_gpu"] = per_device_train_batch_size + config_kwargs = { + "train_micro_batch_size_per_gpu": config_kwargs["train_micro_batch_size_per_gpu"], + "prescale_gradients": False, + "wall_clock_breakdown": False, + } + if bf16: + config_kwargs["bf16"] = {"enabled": True} + elif fp16: + config_kwargs["fp16"] = {"enabled": True} + else: + if hasattr(model, "config"): + hidden_size = ( + max(model.config.hidden_sizes) + if getattr(model.config, "hidden_sizes", None) + else getattr(model.config, "hidden_size", None) + ) + if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3: + # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0` + # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081 + config_kwargs.update( + { + "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, + "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, + "zero_optimization.stage3_prefetch_bucket_size": 0, + } + ) + model, *_ = deepspeed.initialize(model=model, config=config_kwargs) + model.eval() + return model + + +def truncate_response(stop_token_id: int, pad_token_id: int, responses: torch.Tensor): + """ + Truncates the responses at the first occurrence of the stop token, filling the rest with pad tokens. + + Args: + stop_token_id (`int`): + The token ID representing the stop token where truncation occurs. + pad_token_id (`int`): + The token ID representing the pad token used to fill the truncated responses. + responses (`torch.Tensor`): + The tensor containing the responses to be truncated. + + Returns: + `torch.Tensor`: + The truncated responses tensor with pad tokens filled after the stop token. + """ + trunc_idxs = first_true_indices(responses == stop_token_id).unsqueeze(-1) + new_size = [1] * (len(responses.size()) - 1) + [responses.shape[1]] + idxs = torch.arange(responses.shape[1], device=responses.device).view(*new_size) + postprocessed_responses = torch.masked_fill(responses, idxs > trunc_idxs, pad_token_id) + return postprocessed_responses + + +def generate( + lm_backbone: torch.nn.Module, queries: torch.Tensor, pad_token_id: int, generation_config: GenerationConfig +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Generates sequences from the language model backbone in a way that does not affect padding tokens. + + Args: + lm_backbone (`torch.nn.Module`): + The language model backbone used for generation. + queries (`torch.Tensor`): + The tensor containing the input queries. + pad_token_id (`int`): + The token ID representing the pad token. + generation_config (`GenerationConfig`): + The configuration for the generation process. + + Returns: + tuple: + - `generated_sequences` (`torch.Tensor`): + The concatenated tensor of input queries and generated sequences. + - `logits` (`torch.Tensor`): + The logits output from the generation process. + """ + context_length = queries.shape[1] + attention_mask = queries != pad_token_id + input_ids = torch.masked_fill(queries, ~attention_mask, 0) + output = lm_backbone.generate( + input_ids=input_ids, + attention_mask=attention_mask, + # position_ids=attention_mask.cumsum(1) - attention_mask.long(), # not needed: already adjusted in generations + # https://github.com/huggingface/transformers/blob/ac33aeeeee2a7a89b89c93c2962e6feb90daef0a/src/transformers/models/gpt2/modeling_gpt2.py#L1227-L1250 + generation_config=generation_config, + return_dict_in_generate=True, + output_scores=True, + ) + logits = torch.stack(output.scores, 1) + return torch.cat((queries, output.sequences[:, context_length:]), dim=1), logits + + +@torch.no_grad() +def batch_generation( + model: torch.nn.Module, + queries: torch.Tensor, + local_rollout_forward_batch_size: int, + pad_token_id: int, + generation_config: GenerationConfig, +): + query_responses = [] + logitss = [] + batch_size = queries.shape[0] + for i in range(0, batch_size, local_rollout_forward_batch_size): + query = queries[i : i + local_rollout_forward_batch_size] + query_response, logits = generate( + model, + query, + pad_token_id, + generation_config, + ) + query_responses.append(query_response) + logitss.append(logits) + + # padding tensors + padded_query_responses = pad(query_responses, padding_value=pad_token_id, padding_side="right") + padded_logitss = pad(logitss, padding_value=0, padding_side="right") + + # reshaping + padded_query_responses = padded_query_responses.view(-1, padded_query_responses.shape[-1])[:batch_size] + padded_logitss = padded_logitss.view(-1, *padded_logitss.shape[2:])[:batch_size] + + return padded_query_responses, padded_logitss + + +def add_bos_token_if_needed( + bos_token_id: Optional[int], + prompt_len_input_ids: int, + prompt_tokens: dict[str, list[int]], + chosen_prompt_len_input_ids: int, + chosen_tokens: dict[str, list[int]], + rejected_prompt_len_input_ids: int, + rejected_tokens: dict[str, list[int]], +): + if bos_token_id is not None: + if prompt_len_input_ids == 0 or bos_token_id != prompt_tokens["prompt_input_ids"][0]: + prompt_tokens["prompt_input_ids"] = [bos_token_id] + prompt_tokens["prompt_input_ids"] + prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"] + if chosen_prompt_len_input_ids == 0 or bos_token_id != chosen_tokens["prompt_input_ids"][0]: + chosen_tokens["prompt_input_ids"] = [bos_token_id] + chosen_tokens["prompt_input_ids"] + chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"] + if rejected_prompt_len_input_ids == 0 or bos_token_id != rejected_tokens["prompt_input_ids"][0]: + rejected_tokens["prompt_input_ids"] = [bos_token_id] + rejected_tokens["prompt_input_ids"] + rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"] + return prompt_tokens, chosen_tokens, rejected_tokens + + +def add_eos_token_if_needed( + eos_token_id: int, chosen_tokens: dict[str, list[int]], rejected_tokens: dict[str, list[int]] +): + if len(chosen_tokens["input_ids"]) == 0 or eos_token_id != chosen_tokens["input_ids"][-1]: + chosen_tokens["input_ids"].append(eos_token_id) + chosen_tokens["attention_mask"].append(1) + if len(rejected_tokens["input_ids"]) == 0 or eos_token_id != rejected_tokens["input_ids"][-1]: + rejected_tokens["input_ids"].append(eos_token_id) + rejected_tokens["attention_mask"].append(1) + return chosen_tokens, rejected_tokens + + +def truncate_right( + input_ids: torch.Tensor, stop_token_id: int, pad_token_id: int +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Truncates the input tensor from the right side after the first occurrence of the stop token. + + Args: + input_ids (`torch.Tensor`): + The tensor containing the responses to be truncated + stop_token_id (`int`): + The token ID representing the stop token where truncation occurs + pad_token_id (`int`): + The token ID representing the pad token used to fill the truncated responses + + Returns: + tuple: + - `output_ids` (`torch.Tensor`): + The truncated responses tensor with pad tokens filled after the stop token + - `mask` (`torch.Tensor`): + The mask tensor to indicate the padding tokens + """ + trunc_idxs = first_true_indices(input_ids == stop_token_id).unsqueeze(-1) + new_size = [1] * (len(input_ids.size()) - 1) + [input_ids.shape[1]] + idxs = torch.arange(input_ids.shape[1], device=input_ids.device).view(*new_size) + output_ids = torch.masked_fill(input_ids, idxs > trunc_idxs, pad_token_id) + mask = torch.masked_fill(torch.ones_like(input_ids), idxs > trunc_idxs, 0) + return output_ids, mask + + +def empty_cache() -> None: + """Empties the cache of the available torch device. + + This function checks for the availability of different torch devices (XPU, MLU, NPU, CUDA) + and empties the cache of the first available device it finds. + + If none of the specific devices are available, it defaults to emptying the CUDA cache. + """ + if is_torch_xpu_available(): + torch.xpu.empty_cache() + elif is_torch_mlu_available(): + torch.mlu.empty_cache() + elif is_torch_npu_available(): + torch.npu.empty_cache() + else: + torch.cuda.empty_cache() + + +def decode_and_strip_padding(inputs: torch.Tensor, tokenizer: PreTrainedTokenizerBase) -> list[str]: + """ + Decodes the input tensor and strips the padding tokens. + + Args: + inputs (`torch.Tensor`): + The input tensor to be decoded. + tokenizer (`transformers.PreTrainedTokenizerBase`): + The tokenizer used to decode the input tensor. + + Returns: + `list[str]`: + The list of decoded strings with padding tokens stripped. + """ + decoded = tokenizer.batch_decode(inputs, skip_special_tokens=False) + return [d.replace(tokenizer.pad_token, "") for d in decoded] + + +def generate_model_card( + base_model: Optional[str], + model_name: str, + hub_model_id: str, + dataset_name: Optional[str], + tags: list[str], + wandb_url: Optional[str], + trainer_name: str, + trainer_citation: Optional[str] = None, + paper_title: Optional[str] = None, + paper_id: Optional[str] = None, + comet_url: Optional[str] = None, +) -> ModelCard: + """ + Generate a `ModelCard` from a template. + + Args: + base_model (`str` or `None`): + Base model name. + model_name (`str`): + Model name. + hub_model_id (`str`): + Hub model ID as `username/model_id`. + dataset_name (`str` or `None`): + Dataset name. + tags (`list[str]`): + Tags. + wandb_url (`str` or `None`): + Weights & Biases run URL. + comet_url (`str` or `None`): + Comet experiment URL. + trainer_name (`str`): + Trainer name. + trainer_citation (`str` or `None`, defaults to `None`): + Trainer citation as a BibTeX entry. + paper_title (`str` or `None`, defaults to `None`): + Paper title. + paper_id (`str` or `None`, defaults to `None`): + ArXiv paper ID as `YYMM.NNNNN`. + + Returns: + `ModelCard`: + A ModelCard object. + """ + card_data = ModelCardData( + base_model=base_model, + datasets=dataset_name, + library_name="transformers", + licence="license", + model_name=model_name, + tags=["generated_from_trainer", *tags], + ) + card = ModelCard.from_template( + card_data, + template_path=str(pkg_resources.files("trl").joinpath("templates/lm_model_card.md")), + base_model=base_model, + model_name=model_name, + hub_model_id=hub_model_id, + dataset_name=dataset_name, + wandb_url=wandb_url, + comet_url=comet_url, + trainer_name=trainer_name, + trainer_citation=trainer_citation, + paper_title=paper_title, + paper_id=paper_id, + trl_version=version("trl"), + transformers_version=version("transformers"), + pytorch_version=version("torch"), + datasets_version=version("datasets"), + tokenizers_version=version("tokenizers"), + ) + return card + + +def get_comet_experiment_url() -> Optional[str]: + """ + If Comet integration is enabled, return the URL of the current Comet experiment; otherwise, return `None`. + """ + if not is_comet_available(): + return None + + if comet_ml.get_running_experiment() is not None: + return comet_ml.get_running_experiment().url + + return None + + +def log_table_to_comet_experiment(name: str, table: pd.DataFrame) -> None: + """ + If Comet integration is enabled logs a table to the Comet experiment if it is currently running. + + Args: + name (`str`): + Table name. + table (`pd.DataFrame`): + The Pandas DataFrame containing the table to log. + """ + if not is_comet_available(): + raise ModuleNotFoundError("The comet-ml is not installed. Please install it first: pip install comet-ml") + + experiment = comet_ml.get_running_experiment() + if experiment is not None: + experiment.log_table(tabular_data=table, filename=name) + + +def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + """ + Shift non-zero elements in the mask and corresponding tensors to the left. + + This function operates on a binary mask and any number of additional tensors with the same dimensions as the mask. + For each row, non-zero values are shifted to the leftmost positions. Then, columns that contain only zeros across + all rows are truncated from the mask and tensors. Visually, this operation can be represented as follows: + + ``` + [[0, 0, x, x, x, x], -> [[x, x, x, x], + [0, x, x, x, 0, 0]] [x, x, x, 0]] + ``` + + Args: + + mask (`torch.Tensor`): + 2D tensor (binary mask) with shape `(N, M)`. + *tensors (`torch.Tensor`) + One or more 2D tensors with the same shape as `mask`. These tensors will be processed alongside `mask`, + with non-zero values shifted and excess zero columns truncated in the same manner. + + Returns: + `torch.Tensor`: + Updated binary mask with non-zero values flushed to the left and trailing zero columns removed. + `*torch.Tensor` + Updated tensors, processed in the same way as the mask. + + Example: + ```python + >>> mask = torch.tensor([[0, 0, 1, 1, 1], + ... [0, 1, 1, 0, 0]]) + >>> tensor = torch.tensor([[9, 9, 2, 3, 4], + ... [9, 5, 6, 9, 9]]) + >>> new_mask, new_tensor = flush_left(mask, tensor) + >>> print(new_mask) + tensor([[1, 1, 1], + [1, 1, 0]]) + >>> print(new_tensor) + tensor([[2, 3, 4], + [5, 6, 0]]) + ``` + """ + _, M = mask.shape + + # Create copy of mask and tensors + mask_copy = mask.clone() + tensors = [t.clone() for t in tensors] + + # Shift non-zero values to the left + first_non_zero = mask_copy.argmax(dim=1) + pos = torch.arange(M, device=mask_copy.device).unsqueeze(0) + idx_roll = (pos + first_non_zero.unsqueeze(1)) % M + mask_roll = mask_copy.gather(1, idx_roll) + rolled_tensors = [t.gather(1, idx_roll) for t in tensors] + + # Truncate trailing columns that are all zeros in mask_roll + col_sums = mask_roll.sum(dim=0) + empty_cols = col_sums == 0 + first_empty_col = int(empty_cols.to(torch.int8).argmax()) if empty_cols.any() else M + flushed_mask = mask_roll[:, :first_empty_col] + flushed_tensors = [t[:, :first_empty_col] for t in rolled_tensors] + + if not flushed_tensors: + return flushed_mask + return flushed_mask, *flushed_tensors + + +def flush_right(mask: torch.Tensor, *tensors: torch.Tensor) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: + """ + Shift non-zero elements in the mask and corresponding tensors to the right. See `flush_left` for details. + """ + _, M = mask.shape + + # Create copy of mask and tensors + mask_copy = mask.clone() + tensors = [t.clone() for t in tensors] + + # Shift non-zero values to the right + flipped_mask = torch.fliplr(mask_copy) + first_non_zero = flipped_mask.argmax(dim=1) + pos = torch.arange(M, device=mask_copy.device).unsqueeze(0) + idx_roll = (pos - first_non_zero.unsqueeze(1)) % M + mask_roll = mask_copy.gather(1, idx_roll) + rolled_tensors = [t.gather(1, idx_roll) for t in tensors] + + # Truncate leading columns that are all zeros in mask_roll + col_sums = mask_roll.sum(dim=0) + non_empty_cols = col_sums != 0 + first_non_empty_col = int(non_empty_cols.to(torch.int8).argmax()) if non_empty_cols.any() else M + flushed_mask = mask_roll[:, first_non_empty_col:] + flushed_tensors = [t[:, first_non_empty_col:] for t in rolled_tensors] + + if not flushed_tensors: + return flushed_mask + return flushed_mask, *flushed_tensors + + +def selective_log_softmax(logits, index): + """ + A memory-efficient implementation of the common `log_softmax -> gather` operation. + + This function is equivalent to the following naive implementation: + ```python + logps = torch.gather(logits.log_softmax(-1), dim=-1, index=index.unsqueeze(-1)).squeeze(-1) + ``` + + Args: + logits (`torch.Tensor`): + Logits tensor of shape `(..., num_classes)`. + index (`torch.Tensor`): + Index tensor of shape `(...)`, specifying the positions to gather from the log-softmax output. + + Returns: + `torch.Tensor`: + Gathered log probabilities with the same shape as `index`. + """ + if logits.dtype in [torch.float32, torch.float64]: + selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1) + # loop to reduce peak mem consumption + logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits]) + per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x) + else: + # logsumexp approach is unstable with bfloat16, fall back to slightly less efficent approach + per_token_logps = [] + for row_logits, row_labels in zip(logits, index): # loop to reduce peak mem consumption + row_logps = F.log_softmax(row_logits, dim=-1) + row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1) + per_token_logps.append(row_per_token_logps) + per_token_logps = torch.stack(per_token_logps) + return per_token_logps + + +def print_prompt_completions_sample( + prompts: list[str], + completions: list[str], + rewards: dict[str, list[float]], + advantages: list[float], + step: int, + num_samples: int = None, +) -> None: + """ + Print out a sample of model completions to the console with multiple reward metrics. + + This function creates a nicely formatted table showing prompt-completion pairs, useful for monitoring model outputs + during training. It requires the `rich` library to be installed. + + Args: + prompts (`list[str]`): + List of prompts. + completions (`list[str]`): + List of completions corresponding to the prompts. + rewards (`dict[str, list[float]]`): + Dictionary where keys are reward names and values are lists of rewards. + advantages (`list[float]`): + List of advantages corresponding to the prompts and completions. + step (`int`): + Current training step number, used in the output title. + num_samples (`int` or `None`, *optional*, defaults to `None`): + Number of random samples to display. If `None` (default), all items will be displayed. + + Example: + ```python + >>> from trl.trainer.utils import print_prompt_completions_sample + >>> prompts = ["The sky is", "The sun is"] + >>> completions = [" blue.", " in the sky."] + >>> rewards = {"Correctness": [0.123, 0.456], "Format": [0.789, 0.101]} + >>> advantages = [0.987, 0.654] + >>> print_prompt_completions_sample(prompts, completions, rewards, advantages, 42) + ╭──────────────────────────── Step 42 ─────────────────────────────╮ + │ ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━┓ │ + │ ┃ Prompt ┃ Completion ┃ Correctness ┃ Format ┃ Advantage ┃ │ + │ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━┩ │ + │ │ The sky is │ blue. │ 0.12 │ 0.79 │ 0.99 │ │ + │ ├────────────┼──────────────┼─────────────┼────────┼───────────┤ │ + │ │ The sun is │ in the sky. │ 0.46 │ 0.10 │ 0.65 │ │ + │ └────────────┴──────────────┴─────────────┴────────┴───────────┘ │ + ╰──────────────────────────────────────────────────────────────────╯ + ``` + """ + if not is_rich_available(): + raise ImportError( + "The function `print_prompt_completions_sample` requires the `rich` library. Please install it with " + "`pip install rich`." + ) + console = Console() + table = Table(show_header=True, header_style="bold white", expand=True) + + # Add columns + table.add_column("Prompt", style="bright_yellow") + table.add_column("Completion", style="bright_green") + for reward_name in rewards.keys(): + table.add_column(reward_name, style="bold cyan", justify="right") + table.add_column("Advantage", style="bold magenta", justify="right") + + # Some basic input validation + if num_samples is not None: + if num_samples >= len(prompts): + num_samples = None + elif num_samples <= 0: + return + + # Subsample data if num_samples is specified + if num_samples is not None: + indices = random.sample(range(len(prompts)), num_samples) + prompts = [prompts[i] for i in indices] + completions = [completions[i] for i in indices] + rewards = {key: [val[i] for i in indices] for key, val in rewards.items()} + advantages = [advantages[i] for i in indices] + + for i in range(len(prompts)): + reward_values = [f"{rewards[key][i]:.2f}" for key in rewards.keys()] # 2 decimals + table.add_row(Text(prompts[i]), Text(completions[i]), *reward_values, f"{advantages[i]:.2f}") + table.add_section() # Adds a separator between rows + + panel = Panel(table, expand=False, title=f"Step {step}", border_style="bold white") + console.print(panel) diff --git a/trl/trainer/xpo_config.py b/trl/trainer/xpo_config.py new file mode 100644 index 0000000000000000000000000000000000000000..78c1e562128208e0b6a2ba6e78f9becd59e8a1ac --- /dev/null +++ b/trl/trainer/xpo_config.py @@ -0,0 +1,44 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from trl.trainer.online_dpo_config import OnlineDPOConfig + + +@dataclass +class XPOConfig(OnlineDPOConfig): + r""" + Configuration class for the [`XPOTrainer`]. + + Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following: + + Parameters: + alpha (`float` or `list[float]`, *optional*, defaults to `1e-5`): + Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each new epoch + and the last alpha is used for the rest of the epochs. + """ + + alpha: list[float] = field( + default_factory=lambda: [1e-5], + metadata={ + "help": "Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each " + "new epoch and the last alpha is used for the rest of the epochs." + }, + ) + + def __post_init__(self): + super().__post_init__() + if hasattr(self.alpha, "__len__") and len(self.alpha) == 1: + self.alpha = self.alpha[0] diff --git a/trl/trainer/xpo_trainer.py b/trl/trainer/xpo_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..b77ebee7416ec20153bb19c54b45dab82c5e2b27 --- /dev/null +++ b/trl/trainer/xpo_trainer.py @@ -0,0 +1,589 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import textwrap +from typing import Any, Callable, Optional, Union + +import jinja2 +import torch +import torch.nn as nn +import torch.nn.functional as F +from datasets import Dataset, IterableDataset +from transformers import ( + BaseImageProcessor, + FeatureExtractionMixin, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, + is_apex_available, + is_wandb_available, +) +from transformers.trainer_utils import EvalPrediction +from transformers.training_args import OptimizerNames +from transformers.utils import is_peft_available + +from ..data_utils import is_conversational, maybe_apply_chat_template +from ..models.utils import unwrap_model_for_generation +from .judges import BasePairwiseJudge +from .online_dpo_trainer import OnlineDPOTrainer +from .utils import ( + SIMPLE_CHAT_TEMPLATE, + empty_cache, + generate_model_card, + get_comet_experiment_url, + get_reward, + selective_log_softmax, + truncate_right, +) +from .xpo_config import XPOConfig + + +if is_apex_available(): + from apex import amp + + +if is_wandb_available(): + import wandb + + +if is_peft_available(): + from peft import PeftModel + + +class XPOTrainer(OnlineDPOTrainer): + r""" + Initialize XPOTrainer as a subclass of [`OnlineDPOConfig`]. + + Args: + model (`transformers.PreTrainedModel`): + The model to train, preferably an `AutoModelForCausalLM`. + ref_model (`PreTrainedModelWrapper`): + Hugging Face transformer model with a casual language modelling head. Used for implicit reward computation and loss. If no + reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized. + reward_model (`transformers.PreTrainedModel`): + The reward model to score completions with, preferably an `AutoModelForSequenceClassification`. + judge (`BasePairwiseJudge`): + The judge to use for pairwise comparison of model completions. + args (`XPOConfig`): + The XPO config arguments to use for training. + data_collator (`transformers.DataCollator`): + The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used + which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences. + train_dataset (`datasets.Dataset`): + The dataset to use for training. + eval_dataset (`datasets.Dataset`): + The dataset to use for evaluation. + processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*): + Processing class used to process the data. If provided, will be used to automatically process the inputs + for the model, and it will be saved along the model to make it easier to rerun an interrupted training or + reuse the fine-tuned model. + peft_config (`dict`): + The peft config to use for training. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function to use to compute the metrics. Must take a `EvalPrediction` and return + a dictionary string to metric values. + callbacks (`list[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + """ + + _tag_names = ["trl", "xpo"] + + def __init__( + self, + model: Union[PreTrainedModel, nn.Module] = None, + ref_model: Union[PreTrainedModel, nn.Module] = None, + reward_model: Optional[nn.Module] = None, + judge: Optional[BasePairwiseJudge] = None, + args: Optional[XPOConfig] = None, + data_collator: Optional[Callable] = None, + train_dataset: Optional[Union[Dataset, IterableDataset]] = None, + eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, + processing_class: Optional[ + Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] + ] = None, + peft_config: Optional[dict] = None, + compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, + callbacks: Optional[list[TrainerCallback]] = None, + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + ) -> None: + super().__init__( + model=model, + ref_model=ref_model, + judge=judge, + reward_model=reward_model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + reward_processing_class=processing_class, # for now, XPOTrainer can't use any reward model + peft_config=peft_config, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + self._alpha = self.args.alpha + + # Overwrite the stats dictionary to include XPO specific statistics + self.stats = { + # Remove "non_score_reward", "rlhf_reward", "scores" + # Add "loss/dpo", "loss/xpo" + "loss/dpo": [], + "loss/xpo": [], + "objective/kl": [], + "objective/entropy": [], + "rewards/chosen": [], + "rewards/rejected": [], + "rewards/accuracies": [], + "rewards/margins": [], + "logps/chosen": [], + "logps/rejected": [], + # Replace "contain_eos_token" by "model_contain_eos_token" and "ref_contain_eos_token" + "val/model_contain_eos_token": [], + "val/ref_contain_eos_token": [], + "alpha": [], + "beta": [], + } + if self.reward_model is not None: + # Replace "scores" by "model_scores" and "ref_scores" + self.stats["objective/model_scores"] = [] + self.stats["objective/ref_scores"] = [] + self.stats["objective/scores_margin"] = [] + + @property + def alpha(self): + if isinstance(self._alpha, list): + epoch = self.state.epoch + return self._alpha[epoch] if epoch < len(self._alpha) else self._alpha[-1] + else: + return self._alpha + + def _generate_completions(self, prompts, model): + with unwrap_model_for_generation(model, self.accelerator) as unwrapped_policy_model_for_gen: + model_output = unwrapped_policy_model_for_gen.generate( + input_ids=prompts["input_ids"], + attention_mask=prompts["attention_mask"], + generation_config=self.generation_config, + ) + + actual_model_for_ref_generation: torch.nn.Module + if self.ref_model is None: + unwrapped_main_model_for_ref_logic = self.accelerator.unwrap_model(model) + + if is_peft_available() and isinstance(unwrapped_main_model_for_ref_logic, PeftModel): + actual_model_for_ref_generation = unwrapped_main_model_for_ref_logic.get_base_model() + else: + actual_model_for_ref_generation = unwrapped_main_model_for_ref_logic + else: + actual_model_for_ref_generation = self.accelerator.unwrap_model(self.ref_model) + + with unwrap_model_for_generation(actual_model_for_ref_generation, self.accelerator) as final_ref_model_for_gen: + ref_output = final_ref_model_for_gen.generate( + input_ids=prompts["input_ids"], + attention_mask=prompts["attention_mask"], + generation_config=self.generation_config, + ) + + return model_output, ref_output + + def _process_completions(self, model_output, ref_output, prompts): + context_length = prompts["input_ids"].shape[1] + + # Process model completions + model_completion_ids = model_output[:, context_length:] + model_completion_ids, model_completion_mask = truncate_right( + model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id + ) + model_data = { + "input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1), + "attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1), + "raw": prompts["raw"], + } + + # Process reference model completions + ref_completion_ids = ref_output[:, context_length:] + ref_completion_ids, ref_completion_mask = truncate_right( + ref_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id + ) + ref_data = { + "input_ids": torch.cat((prompts["input_ids"], ref_completion_ids), dim=1), + "attention_mask": torch.cat((prompts["attention_mask"], ref_completion_mask), dim=1), + "raw": prompts["raw"], + } + + return model_data, ref_data + + def _compute_rewards(self, model_data, ref_data, context_length): + with torch.no_grad(): + _, model_scores, _ = get_reward( + self.reward_model, model_data["input_ids"], self.processing_class.pad_token_id, context_length + ) + _, ref_scores, _ = get_reward( + self.reward_model, ref_data["input_ids"], self.processing_class.pad_token_id, context_length + ) + + # Apply EOS penalty if needed + if self.args.missing_eos_penalty is not None: + model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1) + ref_contain_eos = torch.any(ref_data["input_ids"] == self.processing_class.eos_token_id, dim=-1) + model_scores[~model_contain_eos] -= self.args.missing_eos_penalty + ref_scores[~ref_contain_eos] -= self.args.missing_eos_penalty + + return model_scores, ref_scores + + def _compute_judge(self, model_data, ref_data, context_length): + prompts = model_data["raw"] + model_data_completions = self.processing_class.batch_decode( + model_data["input_ids"][:, context_length:], skip_special_tokens=True + ) + model_data_completions = [completion.strip() for completion in model_data_completions] + + ref_data_completions = self.processing_class.batch_decode( + ref_data["input_ids"][:, context_length:], skip_special_tokens=True + ) + ref_data_completions = [completion.strip() for completion in ref_data_completions] + + if is_conversational({"prompt": prompts[0]}): + model_data_completions = [ + [{"role": "assistant", "content": completion}] for completion in model_data_completions + ] + environment = jinja2.Environment() + template = environment.from_string(SIMPLE_CHAT_TEMPLATE) + prompts = [template.render(messages=message) for message in prompts] + model_data_completions = [template.render(messages=completion) for completion in model_data_completions] + + ref_data_completions = [ + [{"role": "assistant", "content": completion}] for completion in ref_data_completions + ] + ref_data_completions = [template.render(messages=completion) for completion in ref_data_completions] + + ranks_of_first_completion = self.judge.judge( + prompts, + list(zip(model_data_completions, ref_data_completions)), + ) + # convert ranks to a True/False mask: + # when rank == 0, it means the first completion is the best + # when rank == 1, it means the second completion is the best + return torch.tensor([rank == 0 for rank in ranks_of_first_completion], device=model_data["input_ids"].device) + + def _compute_logprobs(self, model, model_data, ref_data, context_length): + def compute_logprobs_for_data(m, data): + output = m(data["input_ids"], attention_mask=data["attention_mask"]) + logits = output.logits[:, context_length - 1 : -1] + token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:]) + return token_logprobs + + # Compute logprobs for model completions + model_logprobs_model_data = compute_logprobs_for_data(model, model_data) + # Compute logprobs for model on reference completions (for XPO loss) + model_logprobs_ref_data = compute_logprobs_for_data(model, ref_data) + + # Compute logprobs for reference model completions + with torch.no_grad(): + if self.ref_model is None: + with model.disable_adapter(): + ref_logprobs_model_data = compute_logprobs_for_data(model, model_data) + ref_logprobs_ref_data = compute_logprobs_for_data(model, ref_data) + else: + ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data) + ref_logprobs_ref_data = compute_logprobs_for_data(self.ref_model, ref_data) + + # Mask padding tokens + model_padding_mask = model_data["attention_mask"][:, context_length:] == 0 + ref_padding_mask = ref_data["attention_mask"][:, context_length:] == 0 + model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0) + model_logprobs_ref_data = model_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0) + ref_logprobs_ref_data = ref_logprobs_ref_data.masked_fill(ref_padding_mask, 0.0) + ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0) + + return model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data + + def _compute_losses( + self, + model_logprobs_model_data, + model_logprobs_ref_data, + ref_logprobs_ref_data, + ref_logprobs_model_data, + chosen_mask, + ): + # Compute log probs + model_logprobs_model_data_sum = model_logprobs_model_data.sum(1) + model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1) + ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1) + ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1) + + chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum) + chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum) + chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs + + rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum) + rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum) + rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs + + # Compute logits as the difference between chosen and rejected log ratios + logits = chosen_log_ratios - rejected_log_ratios + + if self.args.loss_type == "sigmoid": + dpo_losses = -F.logsigmoid(self.beta * logits) + elif self.args.loss_type == "ipo": + dpo_losses = (logits - 1 / (2 * self.beta)) ** 2 + else: + raise NotImplementedError(f"invalid loss type {self.args.loss_type}") + + # Compute XPO specific loss + xpo_losses = self.alpha * model_logprobs_ref_data_sum + + # Total loss + loss = (dpo_losses + xpo_losses).mean() + + return loss, dpo_losses, xpo_losses + + def _log_statistics( + self, + model_data, + ref_data, + model_logprobs_model_data, + model_logprobs_ref_data, + ref_logprobs_ref_data, + ref_logprobs_model_data, + chosen_mask, + dpo_losses, + xpo_losses, + context_length, + model_scores=None, + ref_scores=None, + ): + # Helper function to gather and compute mean + def gather_mean(tensor): + return self.accelerator.gather_for_metrics(tensor).mean().item() + + # Log losses + self.stats["loss/dpo"].append(gather_mean(dpo_losses)) + self.stats["loss/xpo"].append(gather_mean(xpo_losses)) + + # Log scores + if self.reward_model is not None: + self.stats["objective/model_scores"].append(gather_mean(model_scores)) + self.stats["objective/ref_scores"].append(gather_mean(ref_scores)) + self.stats["objective/scores_margin"].append(gather_mean(model_scores - ref_scores)) + + # Log logprobs + model_logprobs_model_data_sum = model_logprobs_model_data.sum(1) + model_logprobs_ref_data_sum = model_logprobs_ref_data.sum(1) + ref_logprobs_ref_data_sum = ref_logprobs_ref_data.sum(1) + ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1) + + chosen_model_logprobs = torch.where(chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum) + chosen_ref_logprobs = torch.where(chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum) + chosen_log_ratios = chosen_model_logprobs - chosen_ref_logprobs + + rejected_model_logprobs = torch.where(~chosen_mask, model_logprobs_model_data_sum, model_logprobs_ref_data_sum) + rejected_ref_logprobs = torch.where(~chosen_mask, ref_logprobs_model_data_sum, ref_logprobs_ref_data_sum) + rejected_log_ratios = rejected_model_logprobs - rejected_ref_logprobs + + self.stats["logps/chosen"].append(gather_mean(chosen_model_logprobs.mean() + chosen_ref_logprobs.mean())) + self.stats["logps/rejected"].append(gather_mean(rejected_model_logprobs.mean() + rejected_ref_logprobs.mean())) + + # Log rewards + # Compute various statistics + chosen_rewards = chosen_log_ratios * self.beta + rejected_rewards = rejected_log_ratios * self.beta + self.stats["rewards/chosen"].append(gather_mean(chosen_rewards.mean())) + self.stats["rewards/rejected"].append(gather_mean(rejected_rewards.mean())) + + # Calculate KL divergence for model and ref data + kl_model_data = model_logprobs_model_data - ref_logprobs_model_data + kl_ref_data = model_logprobs_ref_data - ref_logprobs_ref_data + mean_kl = (kl_model_data.sum(1) + kl_ref_data.sum(1)).mean() / 2 + self.stats["objective/kl"].append(gather_mean(mean_kl)) + + # Calculate entropy for model and ref data + entropy_model_data = -model_logprobs_model_data.sum(1) + entropy_ref_data = -model_logprobs_ref_data.sum(1) + mean_entropy = (entropy_model_data.mean() + entropy_ref_data.mean()) / 2 + self.stats["objective/entropy"].append(gather_mean(mean_entropy)) + + # Calculate margins + margin = chosen_rewards - rejected_rewards + self.stats["rewards/margins"].append(gather_mean(margin.mean())) + + # Calculate accuracy + accuracy = (margin > 0).float() + self.stats["rewards/accuracies"].append(gather_mean(accuracy.mean())) + + # Log EOS token statistics + model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1) + ref_eos = (ref_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1) + self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float())) + self.stats["val/ref_contain_eos_token"].append(gather_mean(ref_eos.float())) + + # Log alpha and beta + self.stats["alpha"].append(self.alpha) + self.stats["beta"].append(self.beta) + + def training_step( + self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None + ) -> torch.Tensor: + model.train() + + # Apply chat template and tokenize the input + batch_size = len(next(iter(inputs.values()))) + prompts = inputs["prompt"] + inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)] + inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs] + inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs] + inputs = self.data_collator(inputs) + + # need the prompt_ only + inputs = self._prepare_inputs(inputs) + context_length = inputs["prompt_input_ids"].shape[1] + prompts = { + "input_ids": inputs["prompt_input_ids"], + "attention_mask": inputs["prompt_attention_mask"], + "raw": prompts, + } + del inputs + + # Sample completions from both the model and the reference model + model_output, ref_output = self._generate_completions(prompts, model) + + # Process model completions + model_data, ref_data = self._process_completions(model_output, ref_output, prompts) + + # Compute rewards + if self.reward_model is not None: + model_scores, ref_scores = self._compute_rewards(model_data, ref_data, context_length) + chosen_mask = model_scores >= ref_scores + else: + model_scores, ref_scores = None, None + chosen_mask = self._compute_judge(model_data, ref_data, context_length) + + # Compute logprobs + model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data = ( + self._compute_logprobs(model, model_data, ref_data, context_length) + ) + + # Compute loss + loss, dpo_losses, xpo_losses = self._compute_losses( + model_logprobs_model_data, + model_logprobs_ref_data, + ref_logprobs_ref_data, + ref_logprobs_model_data, + chosen_mask, + ) + + # Log everything + self._log_statistics( + model_data, + ref_data, + model_logprobs_model_data.detach(), + model_logprobs_ref_data.detach(), + ref_logprobs_ref_data, + ref_logprobs_model_data, + chosen_mask, + dpo_losses.detach(), + xpo_losses.detach(), + context_length, + model_scores, + ref_scores, + ) + + if ( + self.args.torch_empty_cache_steps is not None + and self.state.global_step % self.args.torch_empty_cache_steps == 0 + ): + empty_cache() + + kwargs = {} + # For LOMO optimizers you need to explicitly use the learning rate + if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: + kwargs["learning_rate"] = self._get_learning_rate() + + if self.args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel training + + if self.use_apex: + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + else: + self.accelerator.backward(loss, **kwargs) + + return loss.detach() / self.args.gradient_accumulation_steps + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, list[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str` or `None`, *optional*, defaults to `None`): + Name of the model. + dataset_name (`str` or `None`, *optional*, defaults to `None`): + Name of the dataset used for training. + tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. + """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + tags = tags or set() + if isinstance(tags, str): + tags = {tags} + + if hasattr(self.model.config, "unsloth_version"): + tags.add("unsloth") + + tags.update(self._tag_names) + + citation = textwrap.dedent("""\ + @article{jung2024binary, + title = {{Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF}}, + author = {Tengyang Xie and Dylan J. Foster and Akshay Krishnamurthy and Corby Rosset and Ahmed Awadallah and Alexander Rakhlin}, + year = 2024, + eprint = {arXiv:2405.21046} + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + comet_url=get_comet_experiment_url(), + trainer_name="XPO", + trainer_citation=citation, + paper_title="Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF", + paper_id="2405.21046", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/uv.lock b/uv.lock new file mode 100644 index 0000000000000000000000000000000000000000..b25c76f930b2dd2e431d973ba2a9b28c9d7cf4d6 --- /dev/null +++ b/uv.lock @@ -0,0 +1,1194 @@ +version = 1 +revision = 2 +requires-python = ">=3.13" + +[[package]] +name = "accelerate" +version = "1.7.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "huggingface-hub" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "psutil" }, + { name = "pyyaml" }, + { name = "safetensors" }, + { name = "torch" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/97/33/47bbd507e3a851d33d19ce7b2141c5ea3689bfae91ba168044d7db24b0e9/accelerate-1.7.0.tar.gz", hash = "sha256:e8a2a5503d6237b9eee73cc8d36cf543f9c2d8dd2c6713450b322f5e6d53a610", size = 376026, upload-time = "2025-05-15T10:00:52.117Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/bb/be8146c196ad6e4dec78385d91e92591f8a433576c4e04c342a636fcd811/accelerate-1.7.0-py3-none-any.whl", hash = "sha256:cf57165cca28769c6cf2650812371c81b18e05743dfa3c748524b1bb4f2b272f", size = 362095, upload-time = "2025-05-15T10:00:49.914Z" }, +] + +[[package]] +name = "aiohappyeyeballs" +version = "2.6.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/26/30/f84a107a9c4331c14b2b586036f40965c128aa4fee4dda5d3d51cb14ad54/aiohappyeyeballs-2.6.1.tar.gz", hash = "sha256:c3f9d0113123803ccadfdf3f0faa505bc78e6a72d1cc4806cbd719826e943558", size = 22760, upload-time = "2025-03-12T01:42:48.764Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0f/15/5bf3b99495fb160b63f95972b81750f18f7f4e02ad051373b669d17d44f2/aiohappyeyeballs-2.6.1-py3-none-any.whl", hash = "sha256:f349ba8f4b75cb25c99c5c2d84e997e485204d2902a9597802b0371f09331fb8", size = 15265, upload-time = "2025-03-12T01:42:47.083Z" }, +] + +[[package]] +name = "aiohttp" +version = "3.12.13" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohappyeyeballs" }, + { name = "aiosignal" }, + { name = "attrs" }, + { name = "frozenlist" }, + { name = "multidict" }, + { name = "propcache" }, + { name = "yarl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/6e/ab88e7cb2a4058bed2f7870276454f85a7c56cd6da79349eb314fc7bbcaa/aiohttp-3.12.13.tar.gz", hash = "sha256:47e2da578528264a12e4e3dd8dd72a7289e5f812758fe086473fab037a10fcce", size = 7819160, upload-time = "2025-06-14T15:15:41.354Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/11/0f/db19abdf2d86aa1deec3c1e0e5ea46a587b97c07a16516b6438428b3a3f8/aiohttp-3.12.13-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:d4a18e61f271127465bdb0e8ff36e8f02ac4a32a80d8927aa52371e93cd87938", size = 694910, upload-time = "2025-06-14T15:14:30.604Z" }, + { url = "https://files.pythonhosted.org/packages/d5/81/0ab551e1b5d7f1339e2d6eb482456ccbe9025605b28eed2b1c0203aaaade/aiohttp-3.12.13-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:532542cb48691179455fab429cdb0d558b5e5290b033b87478f2aa6af5d20ace", size = 472566, upload-time = "2025-06-14T15:14:32.275Z" }, + { url = "https://files.pythonhosted.org/packages/34/3f/6b7d336663337672d29b1f82d1f252ec1a040fe2d548f709d3f90fa2218a/aiohttp-3.12.13-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:d7eea18b52f23c050ae9db5d01f3d264ab08f09e7356d6f68e3f3ac2de9dfabb", size = 464856, upload-time = "2025-06-14T15:14:34.132Z" }, + { url = "https://files.pythonhosted.org/packages/26/7f/32ca0f170496aa2ab9b812630fac0c2372c531b797e1deb3deb4cea904bd/aiohttp-3.12.13-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad7c8e5c25f2a26842a7c239de3f7b6bfb92304593ef997c04ac49fb703ff4d7", size = 1703683, upload-time = "2025-06-14T15:14:36.034Z" }, + { url = "https://files.pythonhosted.org/packages/ec/53/d5513624b33a811c0abea8461e30a732294112318276ce3dbf047dbd9d8b/aiohttp-3.12.13-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:6af355b483e3fe9d7336d84539fef460120c2f6e50e06c658fe2907c69262d6b", size = 1684946, upload-time = "2025-06-14T15:14:38Z" }, + { url = "https://files.pythonhosted.org/packages/37/72/4c237dd127827b0247dc138d3ebd49c2ded6114c6991bbe969058575f25f/aiohttp-3.12.13-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a95cf9f097498f35c88e3609f55bb47b28a5ef67f6888f4390b3d73e2bac6177", size = 1737017, upload-time = "2025-06-14T15:14:39.951Z" }, + { url = "https://files.pythonhosted.org/packages/0d/67/8a7eb3afa01e9d0acc26e1ef847c1a9111f8b42b82955fcd9faeb84edeb4/aiohttp-3.12.13-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b8ed8c38a1c584fe99a475a8f60eefc0b682ea413a84c6ce769bb19a7ff1c5ef", size = 1786390, upload-time = "2025-06-14T15:14:42.151Z" }, + { url = "https://files.pythonhosted.org/packages/48/19/0377df97dd0176ad23cd8cad4fd4232cfeadcec6c1b7f036315305c98e3f/aiohttp-3.12.13-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a0b9170d5d800126b5bc89d3053a2363406d6e327afb6afaeda2d19ee8bb103", size = 1708719, upload-time = "2025-06-14T15:14:44.039Z" }, + { url = "https://files.pythonhosted.org/packages/61/97/ade1982a5c642b45f3622255173e40c3eed289c169f89d00eeac29a89906/aiohttp-3.12.13-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:372feeace612ef8eb41f05ae014a92121a512bd5067db8f25101dd88a8db11da", size = 1622424, upload-time = "2025-06-14T15:14:45.945Z" }, + { url = "https://files.pythonhosted.org/packages/99/ab/00ad3eea004e1d07ccc406e44cfe2b8da5acb72f8c66aeeb11a096798868/aiohttp-3.12.13-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a946d3702f7965d81f7af7ea8fb03bb33fe53d311df48a46eeca17e9e0beed2d", size = 1675447, upload-time = "2025-06-14T15:14:47.911Z" }, + { url = "https://files.pythonhosted.org/packages/3f/fe/74e5ce8b2ccaba445fe0087abc201bfd7259431d92ae608f684fcac5d143/aiohttp-3.12.13-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:a0c4725fae86555bbb1d4082129e21de7264f4ab14baf735278c974785cd2041", size = 1707110, upload-time = "2025-06-14T15:14:50.334Z" }, + { url = "https://files.pythonhosted.org/packages/ef/c4/39af17807f694f7a267bd8ab1fbacf16ad66740862192a6c8abac2bff813/aiohttp-3.12.13-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:9b28ea2f708234f0a5c44eb6c7d9eb63a148ce3252ba0140d050b091b6e842d1", size = 1649706, upload-time = "2025-06-14T15:14:52.378Z" }, + { url = "https://files.pythonhosted.org/packages/38/e8/f5a0a5f44f19f171d8477059aa5f28a158d7d57fe1a46c553e231f698435/aiohttp-3.12.13-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:d4f5becd2a5791829f79608c6f3dc745388162376f310eb9c142c985f9441cc1", size = 1725839, upload-time = "2025-06-14T15:14:54.617Z" }, + { url = "https://files.pythonhosted.org/packages/fd/ac/81acc594c7f529ef4419d3866913f628cd4fa9cab17f7bf410a5c3c04c53/aiohttp-3.12.13-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:60f2ce6b944e97649051d5f5cc0f439360690b73909230e107fd45a359d3e911", size = 1759311, upload-time = "2025-06-14T15:14:56.597Z" }, + { url = "https://files.pythonhosted.org/packages/38/0d/aabe636bd25c6ab7b18825e5a97d40024da75152bec39aa6ac8b7a677630/aiohttp-3.12.13-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:69fc1909857401b67bf599c793f2183fbc4804717388b0b888f27f9929aa41f3", size = 1708202, upload-time = "2025-06-14T15:14:58.598Z" }, + { url = "https://files.pythonhosted.org/packages/1f/ab/561ef2d8a223261683fb95a6283ad0d36cb66c87503f3a7dde7afe208bb2/aiohttp-3.12.13-cp313-cp313-win32.whl", hash = "sha256:7d7e68787a2046b0e44ba5587aa723ce05d711e3a3665b6b7545328ac8e3c0dd", size = 420794, upload-time = "2025-06-14T15:15:00.939Z" }, + { url = "https://files.pythonhosted.org/packages/9d/47/b11d0089875a23bff0abd3edb5516bcd454db3fefab8604f5e4b07bd6210/aiohttp-3.12.13-cp313-cp313-win_amd64.whl", hash = "sha256:5a178390ca90419bfd41419a809688c368e63c86bd725e1186dd97f6b89c2706", size = 446735, upload-time = "2025-06-14T15:15:02.858Z" }, +] + +[[package]] +name = "aiosignal" +version = "1.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "frozenlist" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ba/b5/6d55e80f6d8a08ce22b982eafa278d823b541c925f11ee774b0b9c43473d/aiosignal-1.3.2.tar.gz", hash = "sha256:a8c255c66fafb1e499c9351d0bf32ff2d8a0321595ebac3b93713656d2436f54", size = 19424, upload-time = "2024-12-13T17:10:40.86Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/6a/bc7e17a3e87a2985d3e8f4da4cd0f481060eb78fb08596c42be62c90a4d9/aiosignal-1.3.2-py2.py3-none-any.whl", hash = "sha256:45cde58e409a301715980c2b01d0c28bdde3770d8290b5eb2173759d9acb31a5", size = 7597, upload-time = "2024-12-13T17:10:38.469Z" }, +] + +[[package]] +name = "annotated-types" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ee/67/531ea369ba64dcff5ec9c3402f9f51bf748cec26dde048a2f973a4eea7f5/annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89", size = 16081, upload-time = "2024-05-20T21:33:25.928Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, +] + +[[package]] +name = "attrs" +version = "25.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5a/b0/1367933a8532ee6ff8d63537de4f1177af4bff9f3e829baf7331f595bb24/attrs-25.3.0.tar.gz", hash = "sha256:75d7cefc7fb576747b2c81b4442d4d4a1ce0900973527c011d1030fd3bf4af1b", size = 812032, upload-time = "2025-03-13T11:10:22.779Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/06/bb80f5f86020c4551da315d78b3ab75e8228f89f0162f2c3a819e407941a/attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3", size = 63815, upload-time = "2025-03-13T11:10:21.14Z" }, +] + +[[package]] +name = "certifi" +version = "2025.6.15" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/73/f7/f14b46d4bcd21092d7d3ccef689615220d8a08fb25e564b65d20738e672e/certifi-2025.6.15.tar.gz", hash = "sha256:d747aa5a8b9bbbb1bb8c22bb13e22bd1f18e9796defa16bab421f7f7a317323b", size = 158753, upload-time = "2025-06-15T02:45:51.329Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/84/ae/320161bd181fc06471eed047ecce67b693fd7515b16d495d8932db763426/certifi-2025.6.15-py3-none-any.whl", hash = "sha256:2e0c7ce7cb5d8f8634ca55d2ba7e6ec2689a2fd6537d8dec1296a477a4910057", size = 157650, upload-time = "2025-06-15T02:45:49.977Z" }, +] + +[[package]] +name = "charset-normalizer" +version = "3.4.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e4/33/89c2ced2b67d1c2a61c19c6751aa8902d46ce3dacb23600a283619f5a12d/charset_normalizer-3.4.2.tar.gz", hash = "sha256:5baececa9ecba31eff645232d59845c07aa030f0c81ee70184a90d35099a0e63", size = 126367, upload-time = "2025-05-02T08:34:42.01Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ea/12/a93df3366ed32db1d907d7593a94f1fe6293903e3e92967bebd6950ed12c/charset_normalizer-3.4.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:926ca93accd5d36ccdabd803392ddc3e03e6d4cd1cf17deff3b989ab8e9dbcf0", size = 199622, upload-time = "2025-05-02T08:32:56.363Z" }, + { url = "https://files.pythonhosted.org/packages/04/93/bf204e6f344c39d9937d3c13c8cd5bbfc266472e51fc8c07cb7f64fcd2de/charset_normalizer-3.4.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eba9904b0f38a143592d9fc0e19e2df0fa2e41c3c3745554761c5f6447eedabf", size = 143435, upload-time = "2025-05-02T08:32:58.551Z" }, + { url = "https://files.pythonhosted.org/packages/22/2a/ea8a2095b0bafa6c5b5a55ffdc2f924455233ee7b91c69b7edfcc9e02284/charset_normalizer-3.4.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3fddb7e2c84ac87ac3a947cb4e66d143ca5863ef48e4a5ecb83bd48619e4634e", size = 153653, upload-time = "2025-05-02T08:33:00.342Z" }, + { url = "https://files.pythonhosted.org/packages/b6/57/1b090ff183d13cef485dfbe272e2fe57622a76694061353c59da52c9a659/charset_normalizer-3.4.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:98f862da73774290f251b9df8d11161b6cf25b599a66baf087c1ffe340e9bfd1", size = 146231, upload-time = "2025-05-02T08:33:02.081Z" }, + { url = "https://files.pythonhosted.org/packages/e2/28/ffc026b26f441fc67bd21ab7f03b313ab3fe46714a14b516f931abe1a2d8/charset_normalizer-3.4.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c9379d65defcab82d07b2a9dfbfc2e95bc8fe0ebb1b176a3190230a3ef0e07c", size = 148243, upload-time = "2025-05-02T08:33:04.063Z" }, + { url = "https://files.pythonhosted.org/packages/c0/0f/9abe9bd191629c33e69e47c6ef45ef99773320e9ad8e9cb08b8ab4a8d4cb/charset_normalizer-3.4.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e635b87f01ebc977342e2697d05b56632f5f879a4f15955dfe8cef2448b51691", size = 150442, upload-time = "2025-05-02T08:33:06.418Z" }, + { url = "https://files.pythonhosted.org/packages/67/7c/a123bbcedca91d5916c056407f89a7f5e8fdfce12ba825d7d6b9954a1a3c/charset_normalizer-3.4.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:1c95a1e2902a8b722868587c0e1184ad5c55631de5afc0eb96bc4b0d738092c0", size = 145147, upload-time = "2025-05-02T08:33:08.183Z" }, + { url = "https://files.pythonhosted.org/packages/ec/fe/1ac556fa4899d967b83e9893788e86b6af4d83e4726511eaaad035e36595/charset_normalizer-3.4.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:ef8de666d6179b009dce7bcb2ad4c4a779f113f12caf8dc77f0162c29d20490b", size = 153057, upload-time = "2025-05-02T08:33:09.986Z" }, + { url = "https://files.pythonhosted.org/packages/2b/ff/acfc0b0a70b19e3e54febdd5301a98b72fa07635e56f24f60502e954c461/charset_normalizer-3.4.2-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:32fc0341d72e0f73f80acb0a2c94216bd704f4f0bce10aedea38f30502b271ff", size = 156454, upload-time = "2025-05-02T08:33:11.814Z" }, + { url = "https://files.pythonhosted.org/packages/92/08/95b458ce9c740d0645feb0e96cea1f5ec946ea9c580a94adfe0b617f3573/charset_normalizer-3.4.2-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:289200a18fa698949d2b39c671c2cc7a24d44096784e76614899a7ccf2574b7b", size = 154174, upload-time = "2025-05-02T08:33:13.707Z" }, + { url = "https://files.pythonhosted.org/packages/78/be/8392efc43487ac051eee6c36d5fbd63032d78f7728cb37aebcc98191f1ff/charset_normalizer-3.4.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4a476b06fbcf359ad25d34a057b7219281286ae2477cc5ff5e3f70a246971148", size = 149166, upload-time = "2025-05-02T08:33:15.458Z" }, + { url = "https://files.pythonhosted.org/packages/44/96/392abd49b094d30b91d9fbda6a69519e95802250b777841cf3bda8fe136c/charset_normalizer-3.4.2-cp313-cp313-win32.whl", hash = "sha256:aaeeb6a479c7667fbe1099af9617c83aaca22182d6cf8c53966491a0f1b7ffb7", size = 98064, upload-time = "2025-05-02T08:33:17.06Z" }, + { url = "https://files.pythonhosted.org/packages/e9/b0/0200da600134e001d91851ddc797809e2fe0ea72de90e09bec5a2fbdaccb/charset_normalizer-3.4.2-cp313-cp313-win_amd64.whl", hash = "sha256:aa6af9e7d59f9c12b33ae4e9450619cf2488e2bbe9b44030905877f0b2324980", size = 105641, upload-time = "2025-05-02T08:33:18.753Z" }, + { url = "https://files.pythonhosted.org/packages/20/94/c5790835a017658cbfabd07f3bfb549140c3ac458cfc196323996b10095a/charset_normalizer-3.4.2-py3-none-any.whl", hash = "sha256:7f56930ab0abd1c45cd15be65cc741c28b1c9a34876ce8c17a2fa107810c0af0", size = 52626, upload-time = "2025-05-02T08:34:40.053Z" }, +] + +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, +] + +[[package]] +name = "datasets" +version = "3.6.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "dill" }, + { name = "filelock" }, + { name = "fsspec", extra = ["http"] }, + { name = "huggingface-hub" }, + { name = "multiprocess" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "pandas" }, + { name = "pyarrow" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "tqdm" }, + { name = "xxhash" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1a/89/d3d6fef58a488f8569c82fd293ab7cbd4250244d67f425dcae64c63800ea/datasets-3.6.0.tar.gz", hash = "sha256:1b2bf43b19776e2787e181cfd329cb0ca1a358ea014780c3581e0f276375e041", size = 569336, upload-time = "2025-05-07T15:15:02.659Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/34/a08b0ee99715eaba118cbe19a71f7b5e2425c2718ef96007c325944a1152/datasets-3.6.0-py3-none-any.whl", hash = "sha256:25000c4a2c0873a710df127d08a202a06eab7bf42441a6bc278b499c2f72cd1b", size = 491546, upload-time = "2025-05-07T15:14:59.742Z" }, +] + +[[package]] +name = "deepspeed" +version = "0.17.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "einops" }, + { name = "hjson" }, + { name = "msgpack" }, + { name = "ninja" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "psutil" }, + { name = "py-cpuinfo" }, + { name = "pydantic" }, + { name = "torch" }, + { name = "tqdm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/38/10/a7f63e086c1e1c12e290c98363c748ef5ddd6313fde739d2aeccd5ed0cd4/deepspeed-0.17.1.tar.gz", hash = "sha256:6d6e21796982b9e024f489e1c211666cc6c0be6e344751368610b9d2da285d6e", size = 1547985, upload-time = "2025-06-09T22:53:11.543Z" } + +[[package]] +name = "dill" +version = "0.3.8" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/17/4d/ac7ffa80c69ea1df30a8aa11b3578692a5118e7cd1aa157e3ef73b092d15/dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca", size = 184847, upload-time = "2024-01-27T23:42:16.145Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c9/7a/cef76fd8438a42f96db64ddaa85280485a9c395e7df3db8158cfec1eee34/dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7", size = 116252, upload-time = "2024-01-27T23:42:14.239Z" }, +] + +[[package]] +name = "einops" +version = "0.8.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e5/81/df4fbe24dff8ba3934af99044188e20a98ed441ad17a274539b74e82e126/einops-0.8.1.tar.gz", hash = "sha256:de5d960a7a761225532e0f1959e5315ebeafc0cd43394732f103ca44b9837e84", size = 54805, upload-time = "2025-02-09T03:17:00.434Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/62/9773de14fe6c45c23649e98b83231fffd7b9892b6cf863251dc2afa73643/einops-0.8.1-py3-none-any.whl", hash = "sha256:919387eb55330f5757c6bea9165c5ff5cfe63a642682ea788a6d472576d81737", size = 64359, upload-time = "2025-02-09T03:17:01.998Z" }, +] + +[[package]] +name = "filelock" +version = "3.18.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0a/10/c23352565a6544bdc5353e0b15fc1c563352101f30e24bf500207a54df9a/filelock-3.18.0.tar.gz", hash = "sha256:adbc88eabb99d2fec8c9c1b229b171f18afa655400173ddc653d5d01501fb9f2", size = 18075, upload-time = "2025-03-14T07:11:40.47Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4d/36/2a115987e2d8c300a974597416d9de88f2444426de9571f4b59b2cca3acc/filelock-3.18.0-py3-none-any.whl", hash = "sha256:c401f4f8377c4464e6db25fff06205fd89bdd83b65eb0488ed1b160f780e21de", size = 16215, upload-time = "2025-03-14T07:11:39.145Z" }, +] + +[[package]] +name = "frozenlist" +version = "1.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/79/b1/b64018016eeb087db503b038296fd782586432b9c077fc5c7839e9cb6ef6/frozenlist-1.7.0.tar.gz", hash = "sha256:2e310d81923c2437ea8670467121cc3e9b0f76d3043cc1d2331d56c7fb7a3a8f", size = 45078, upload-time = "2025-06-09T23:02:35.538Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/24/90/6b2cebdabdbd50367273c20ff6b57a3dfa89bd0762de02c3a1eb42cb6462/frozenlist-1.7.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ee80eeda5e2a4e660651370ebffd1286542b67e268aa1ac8d6dbe973120ef7ee", size = 79791, upload-time = "2025-06-09T23:01:09.368Z" }, + { url = "https://files.pythonhosted.org/packages/83/2e/5b70b6a3325363293fe5fc3ae74cdcbc3e996c2a11dde2fd9f1fb0776d19/frozenlist-1.7.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:d1a81c85417b914139e3a9b995d4a1c84559afc839a93cf2cb7f15e6e5f6ed2d", size = 47165, upload-time = "2025-06-09T23:01:10.653Z" }, + { url = "https://files.pythonhosted.org/packages/f4/25/a0895c99270ca6966110f4ad98e87e5662eab416a17e7fd53c364bf8b954/frozenlist-1.7.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:cbb65198a9132ebc334f237d7b0df163e4de83fb4f2bdfe46c1e654bdb0c5d43", size = 45881, upload-time = "2025-06-09T23:01:12.296Z" }, + { url = "https://files.pythonhosted.org/packages/19/7c/71bb0bbe0832793c601fff68cd0cf6143753d0c667f9aec93d3c323f4b55/frozenlist-1.7.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dab46c723eeb2c255a64f9dc05b8dd601fde66d6b19cdb82b2e09cc6ff8d8b5d", size = 232409, upload-time = "2025-06-09T23:01:13.641Z" }, + { url = "https://files.pythonhosted.org/packages/c0/45/ed2798718910fe6eb3ba574082aaceff4528e6323f9a8570be0f7028d8e9/frozenlist-1.7.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:6aeac207a759d0dedd2e40745575ae32ab30926ff4fa49b1635def65806fddee", size = 225132, upload-time = "2025-06-09T23:01:15.264Z" }, + { url = "https://files.pythonhosted.org/packages/ba/e2/8417ae0f8eacb1d071d4950f32f229aa6bf68ab69aab797b72a07ea68d4f/frozenlist-1.7.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bd8c4e58ad14b4fa7802b8be49d47993182fdd4023393899632c88fd8cd994eb", size = 237638, upload-time = "2025-06-09T23:01:16.752Z" }, + { url = "https://files.pythonhosted.org/packages/f8/b7/2ace5450ce85f2af05a871b8c8719b341294775a0a6c5585d5e6170f2ce7/frozenlist-1.7.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:04fb24d104f425da3540ed83cbfc31388a586a7696142004c577fa61c6298c3f", size = 233539, upload-time = "2025-06-09T23:01:18.202Z" }, + { url = "https://files.pythonhosted.org/packages/46/b9/6989292c5539553dba63f3c83dc4598186ab2888f67c0dc1d917e6887db6/frozenlist-1.7.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6a5c505156368e4ea6b53b5ac23c92d7edc864537ff911d2fb24c140bb175e60", size = 215646, upload-time = "2025-06-09T23:01:19.649Z" }, + { url = "https://files.pythonhosted.org/packages/72/31/bc8c5c99c7818293458fe745dab4fd5730ff49697ccc82b554eb69f16a24/frozenlist-1.7.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8bd7eb96a675f18aa5c553eb7ddc24a43c8c18f22e1f9925528128c052cdbe00", size = 232233, upload-time = "2025-06-09T23:01:21.175Z" }, + { url = "https://files.pythonhosted.org/packages/59/52/460db4d7ba0811b9ccb85af996019f5d70831f2f5f255f7cc61f86199795/frozenlist-1.7.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:05579bf020096fe05a764f1f84cd104a12f78eaab68842d036772dc6d4870b4b", size = 227996, upload-time = "2025-06-09T23:01:23.098Z" }, + { url = "https://files.pythonhosted.org/packages/ba/c9/f4b39e904c03927b7ecf891804fd3b4df3db29b9e487c6418e37988d6e9d/frozenlist-1.7.0-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:376b6222d114e97eeec13d46c486facd41d4f43bab626b7c3f6a8b4e81a5192c", size = 242280, upload-time = "2025-06-09T23:01:24.808Z" }, + { url = "https://files.pythonhosted.org/packages/b8/33/3f8d6ced42f162d743e3517781566b8481322be321b486d9d262adf70bfb/frozenlist-1.7.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:0aa7e176ebe115379b5b1c95b4096fb1c17cce0847402e227e712c27bdb5a949", size = 217717, upload-time = "2025-06-09T23:01:26.28Z" }, + { url = "https://files.pythonhosted.org/packages/3e/e8/ad683e75da6ccef50d0ab0c2b2324b32f84fc88ceee778ed79b8e2d2fe2e/frozenlist-1.7.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:3fbba20e662b9c2130dc771e332a99eff5da078b2b2648153a40669a6d0e36ca", size = 236644, upload-time = "2025-06-09T23:01:27.887Z" }, + { url = "https://files.pythonhosted.org/packages/b2/14/8d19ccdd3799310722195a72ac94ddc677541fb4bef4091d8e7775752360/frozenlist-1.7.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:f3f4410a0a601d349dd406b5713fec59b4cee7e71678d5b17edda7f4655a940b", size = 238879, upload-time = "2025-06-09T23:01:29.524Z" }, + { url = "https://files.pythonhosted.org/packages/ce/13/c12bf657494c2fd1079a48b2db49fa4196325909249a52d8f09bc9123fd7/frozenlist-1.7.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e2cdfaaec6a2f9327bf43c933c0319a7c429058e8537c508964a133dffee412e", size = 232502, upload-time = "2025-06-09T23:01:31.287Z" }, + { url = "https://files.pythonhosted.org/packages/d7/8b/e7f9dfde869825489382bc0d512c15e96d3964180c9499efcec72e85db7e/frozenlist-1.7.0-cp313-cp313-win32.whl", hash = "sha256:5fc4df05a6591c7768459caba1b342d9ec23fa16195e744939ba5914596ae3e1", size = 39169, upload-time = "2025-06-09T23:01:35.503Z" }, + { url = "https://files.pythonhosted.org/packages/35/89/a487a98d94205d85745080a37860ff5744b9820a2c9acbcdd9440bfddf98/frozenlist-1.7.0-cp313-cp313-win_amd64.whl", hash = "sha256:52109052b9791a3e6b5d1b65f4b909703984b770694d3eb64fad124c835d7cba", size = 43219, upload-time = "2025-06-09T23:01:36.784Z" }, + { url = "https://files.pythonhosted.org/packages/56/d5/5c4cf2319a49eddd9dd7145e66c4866bdc6f3dbc67ca3d59685149c11e0d/frozenlist-1.7.0-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:a6f86e4193bb0e235ef6ce3dde5cbabed887e0b11f516ce8a0f4d3b33078ec2d", size = 84345, upload-time = "2025-06-09T23:01:38.295Z" }, + { url = "https://files.pythonhosted.org/packages/a4/7d/ec2c1e1dc16b85bc9d526009961953df9cec8481b6886debb36ec9107799/frozenlist-1.7.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:82d664628865abeb32d90ae497fb93df398a69bb3434463d172b80fc25b0dd7d", size = 48880, upload-time = "2025-06-09T23:01:39.887Z" }, + { url = "https://files.pythonhosted.org/packages/69/86/f9596807b03de126e11e7d42ac91e3d0b19a6599c714a1989a4e85eeefc4/frozenlist-1.7.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:912a7e8375a1c9a68325a902f3953191b7b292aa3c3fb0d71a216221deca460b", size = 48498, upload-time = "2025-06-09T23:01:41.318Z" }, + { url = "https://files.pythonhosted.org/packages/5e/cb/df6de220f5036001005f2d726b789b2c0b65f2363b104bbc16f5be8084f8/frozenlist-1.7.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9537c2777167488d539bc5de2ad262efc44388230e5118868e172dd4a552b146", size = 292296, upload-time = "2025-06-09T23:01:42.685Z" }, + { url = "https://files.pythonhosted.org/packages/83/1f/de84c642f17c8f851a2905cee2dae401e5e0daca9b5ef121e120e19aa825/frozenlist-1.7.0-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:f34560fb1b4c3e30ba35fa9a13894ba39e5acfc5f60f57d8accde65f46cc5e74", size = 273103, upload-time = "2025-06-09T23:01:44.166Z" }, + { url = "https://files.pythonhosted.org/packages/88/3c/c840bfa474ba3fa13c772b93070893c6e9d5c0350885760376cbe3b6c1b3/frozenlist-1.7.0-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:acd03d224b0175f5a850edc104ac19040d35419eddad04e7cf2d5986d98427f1", size = 292869, upload-time = "2025-06-09T23:01:45.681Z" }, + { url = "https://files.pythonhosted.org/packages/a6/1c/3efa6e7d5a39a1d5ef0abeb51c48fb657765794a46cf124e5aca2c7a592c/frozenlist-1.7.0-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f2038310bc582f3d6a09b3816ab01737d60bf7b1ec70f5356b09e84fb7408ab1", size = 291467, upload-time = "2025-06-09T23:01:47.234Z" }, + { url = "https://files.pythonhosted.org/packages/4f/00/d5c5e09d4922c395e2f2f6b79b9a20dab4b67daaf78ab92e7729341f61f6/frozenlist-1.7.0-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8c05e4c8e5f36e5e088caa1bf78a687528f83c043706640a92cb76cd6999384", size = 266028, upload-time = "2025-06-09T23:01:48.819Z" }, + { url = "https://files.pythonhosted.org/packages/4e/27/72765be905619dfde25a7f33813ac0341eb6b076abede17a2e3fbfade0cb/frozenlist-1.7.0-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:765bb588c86e47d0b68f23c1bee323d4b703218037765dcf3f25c838c6fecceb", size = 284294, upload-time = "2025-06-09T23:01:50.394Z" }, + { url = "https://files.pythonhosted.org/packages/88/67/c94103a23001b17808eb7dd1200c156bb69fb68e63fcf0693dde4cd6228c/frozenlist-1.7.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:32dc2e08c67d86d0969714dd484fd60ff08ff81d1a1e40a77dd34a387e6ebc0c", size = 281898, upload-time = "2025-06-09T23:01:52.234Z" }, + { url = "https://files.pythonhosted.org/packages/42/34/a3e2c00c00f9e2a9db5653bca3fec306349e71aff14ae45ecc6d0951dd24/frozenlist-1.7.0-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:c0303e597eb5a5321b4de9c68e9845ac8f290d2ab3f3e2c864437d3c5a30cd65", size = 290465, upload-time = "2025-06-09T23:01:53.788Z" }, + { url = "https://files.pythonhosted.org/packages/bb/73/f89b7fbce8b0b0c095d82b008afd0590f71ccb3dee6eee41791cf8cd25fd/frozenlist-1.7.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:a47f2abb4e29b3a8d0b530f7c3598badc6b134562b1a5caee867f7c62fee51e3", size = 266385, upload-time = "2025-06-09T23:01:55.769Z" }, + { url = "https://files.pythonhosted.org/packages/cd/45/e365fdb554159462ca12df54bc59bfa7a9a273ecc21e99e72e597564d1ae/frozenlist-1.7.0-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:3d688126c242a6fabbd92e02633414d40f50bb6002fa4cf995a1d18051525657", size = 288771, upload-time = "2025-06-09T23:01:57.4Z" }, + { url = "https://files.pythonhosted.org/packages/00/11/47b6117002a0e904f004d70ec5194fe9144f117c33c851e3d51c765962d0/frozenlist-1.7.0-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:4e7e9652b3d367c7bd449a727dc79d5043f48b88d0cbfd4f9f1060cf2b414104", size = 288206, upload-time = "2025-06-09T23:01:58.936Z" }, + { url = "https://files.pythonhosted.org/packages/40/37/5f9f3c3fd7f7746082ec67bcdc204db72dad081f4f83a503d33220a92973/frozenlist-1.7.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:1a85e345b4c43db8b842cab1feb41be5cc0b10a1830e6295b69d7310f99becaf", size = 282620, upload-time = "2025-06-09T23:02:00.493Z" }, + { url = "https://files.pythonhosted.org/packages/0b/31/8fbc5af2d183bff20f21aa743b4088eac4445d2bb1cdece449ae80e4e2d1/frozenlist-1.7.0-cp313-cp313t-win32.whl", hash = "sha256:3a14027124ddb70dfcee5148979998066897e79f89f64b13328595c4bdf77c81", size = 43059, upload-time = "2025-06-09T23:02:02.072Z" }, + { url = "https://files.pythonhosted.org/packages/bb/ed/41956f52105b8dbc26e457c5705340c67c8cc2b79f394b79bffc09d0e938/frozenlist-1.7.0-cp313-cp313t-win_amd64.whl", hash = "sha256:3bf8010d71d4507775f658e9823210b7427be36625b387221642725b515dcf3e", size = 47516, upload-time = "2025-06-09T23:02:03.779Z" }, + { url = "https://files.pythonhosted.org/packages/ee/45/b82e3c16be2182bff01179db177fe144d58b5dc787a7d4492c6ed8b9317f/frozenlist-1.7.0-py3-none-any.whl", hash = "sha256:9a5af342e34f7e97caf8c995864c7a396418ae2859cc6fdf1b1073020d516a7e", size = 13106, upload-time = "2025-06-09T23:02:34.204Z" }, +] + +[[package]] +name = "fsspec" +version = "2025.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/34/f4/5721faf47b8c499e776bc34c6a8fc17efdf7fdef0b00f398128bc5dcb4ac/fsspec-2025.3.0.tar.gz", hash = "sha256:a935fd1ea872591f2b5148907d103488fc523295e6c64b835cfad8c3eca44972", size = 298491, upload-time = "2025-03-07T21:47:56.461Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/56/53/eb690efa8513166adef3e0669afd31e95ffde69fb3c52ec2ac7223ed6018/fsspec-2025.3.0-py3-none-any.whl", hash = "sha256:efb87af3efa9103f94ca91a7f8cb7a4df91af9f74fc106c9c7ea0efd7277c1b3", size = 193615, upload-time = "2025-03-07T21:47:54.809Z" }, +] + +[package.optional-dependencies] +http = [ + { name = "aiohttp" }, +] + +[[package]] +name = "hf-xet" +version = "1.1.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/75/dc/dc091aeeb671e71cbec30e84963f9c0202c17337b24b0a800e7d205543e8/hf_xet-1.1.3.tar.gz", hash = "sha256:a5f09b1dd24e6ff6bcedb4b0ddab2d81824098bb002cf8b4ffa780545fa348c3", size = 488127, upload-time = "2025-06-04T00:47:27.456Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9b/1f/bc01a4c0894973adebbcd4aa338a06815c76333ebb3921d94dcbd40dae6a/hf_xet-1.1.3-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:c3b508b5f583a75641aebf732853deb058953370ce8184f5dabc49f803b0819b", size = 2256929, upload-time = "2025-06-04T00:47:21.206Z" }, + { url = "https://files.pythonhosted.org/packages/78/07/6ef50851b5c6b45b77a6e018fa299c69a2db3b8bbd0d5af594c0238b1ceb/hf_xet-1.1.3-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:b788a61977fbe6b5186e66239e2a329a3f0b7e7ff50dad38984c0c74f44aeca1", size = 2153719, upload-time = "2025-06-04T00:47:19.302Z" }, + { url = "https://files.pythonhosted.org/packages/52/48/e929e6e3db6e4758c2adf0f2ca2c59287f1b76229d8bdc1a4c9cfc05212e/hf_xet-1.1.3-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd2da210856444a34aad8ada2fc12f70dabed7cc20f37e90754d1d9b43bc0534", size = 4820519, upload-time = "2025-06-04T00:47:17.244Z" }, + { url = "https://files.pythonhosted.org/packages/28/2e/03f89c5014a5aafaa9b150655f811798a317036646623bdaace25f485ae8/hf_xet-1.1.3-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:8203f52827e3df65981984936654a5b390566336956f65765a8aa58c362bb841", size = 4964121, upload-time = "2025-06-04T00:47:15.17Z" }, + { url = "https://files.pythonhosted.org/packages/47/8b/5cd399a92b47d98086f55fc72d69bc9ea5e5c6f27a9ed3e0cdd6be4e58a3/hf_xet-1.1.3-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:30c575a5306f8e6fda37edb866762140a435037365eba7a17ce7bd0bc0216a8b", size = 5283017, upload-time = "2025-06-04T00:47:23.239Z" }, + { url = "https://files.pythonhosted.org/packages/53/e3/2fcec58d2fcfd25ff07feb876f466cfa11f8dcf9d3b742c07fe9dd51ee0a/hf_xet-1.1.3-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:7c1a6aa6abed1f696f8099aa9796ca04c9ee778a58728a115607de9cc4638ff1", size = 4970349, upload-time = "2025-06-04T00:47:25.383Z" }, + { url = "https://files.pythonhosted.org/packages/53/bf/10ca917e335861101017ff46044c90e517b574fbb37219347b83be1952f6/hf_xet-1.1.3-cp37-abi3-win_amd64.whl", hash = "sha256:b578ae5ac9c056296bb0df9d018e597c8dc6390c5266f35b5c44696003cde9f3", size = 2310934, upload-time = "2025-06-04T00:47:29.632Z" }, +] + +[[package]] +name = "hjson" +version = "3.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/82/e5/0b56d723a76ca67abadbf7fb71609fb0ea7e6926e94fcca6c65a85b36a0e/hjson-3.1.0.tar.gz", hash = "sha256:55af475a27cf83a7969c808399d7bccdec8fb836a07ddbd574587593b9cdcf75", size = 40541, upload-time = "2022-08-13T02:53:01.919Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1f/7f/13cd798d180af4bf4c0ceddeefba2b864a63c71645abc0308b768d67bb81/hjson-3.1.0-py3-none-any.whl", hash = "sha256:65713cdcf13214fb554eb8b4ef803419733f4f5e551047c9b711098ab7186b89", size = 54018, upload-time = "2022-08-13T02:52:59.899Z" }, +] + +[[package]] +name = "huggingface-hub" +version = "0.33.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "filelock" }, + { name = "fsspec" }, + { name = "hf-xet", marker = "platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" }, + { name = "packaging" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "tqdm" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/91/8a/1362d565fefabaa4185cf3ae842a98dbc5b35146f5694f7080f043a6952f/huggingface_hub-0.33.0.tar.gz", hash = "sha256:aa31f70d29439d00ff7a33837c03f1f9dd83971ce4e29ad664d63ffb17d3bb97", size = 426179, upload-time = "2025-06-11T17:08:07.913Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/33/fb/53587a89fbc00799e4179796f51b3ad713c5de6bb680b2becb6d37c94649/huggingface_hub-0.33.0-py3-none-any.whl", hash = "sha256:e8668875b40c68f9929150d99727d39e5ebb8a05a98e4191b908dc7ded9074b3", size = 514799, upload-time = "2025-06-11T17:08:05.757Z" }, +] + +[[package]] +name = "idna" +version = "3.10" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490, upload-time = "2024-09-15T18:07:39.745Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" }, +] + +[[package]] +name = "jinja2" +version = "3.1.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/df/bf/f7da0350254c0ed7c72f3e33cef02e048281fec7ecec5f032d4aac52226b/jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d", size = 245115, upload-time = "2025-03-05T20:05:02.478Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, +] + +[[package]] +name = "markupsafe" +version = "3.0.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b2/97/5d42485e71dfc078108a86d6de8fa46db44a1a9295e89c5d6d4a06e23a62/markupsafe-3.0.2.tar.gz", hash = "sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0", size = 20537, upload-time = "2024-10-18T15:21:54.129Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/83/0e/67eb10a7ecc77a0c2bbe2b0235765b98d164d81600746914bebada795e97/MarkupSafe-3.0.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ba9527cdd4c926ed0760bc301f6728ef34d841f405abf9d4f959c478421e4efd", size = 14274, upload-time = "2024-10-18T15:21:24.577Z" }, + { url = "https://files.pythonhosted.org/packages/2b/6d/9409f3684d3335375d04e5f05744dfe7e9f120062c9857df4ab490a1031a/MarkupSafe-3.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f8b3d067f2e40fe93e1ccdd6b2e1d16c43140e76f02fb1319a05cf2b79d99430", size = 12352, upload-time = "2024-10-18T15:21:25.382Z" }, + { url = "https://files.pythonhosted.org/packages/d2/f5/6eadfcd3885ea85fe2a7c128315cc1bb7241e1987443d78c8fe712d03091/MarkupSafe-3.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:569511d3b58c8791ab4c2e1285575265991e6d8f8700c7be0e88f86cb0672094", size = 24122, upload-time = "2024-10-18T15:21:26.199Z" }, + { url = "https://files.pythonhosted.org/packages/0c/91/96cf928db8236f1bfab6ce15ad070dfdd02ed88261c2afafd4b43575e9e9/MarkupSafe-3.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15ab75ef81add55874e7ab7055e9c397312385bd9ced94920f2802310c930396", size = 23085, upload-time = "2024-10-18T15:21:27.029Z" }, + { url = "https://files.pythonhosted.org/packages/c2/cf/c9d56af24d56ea04daae7ac0940232d31d5a8354f2b457c6d856b2057d69/MarkupSafe-3.0.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f3818cb119498c0678015754eba762e0d61e5b52d34c8b13d770f0719f7b1d79", size = 22978, upload-time = "2024-10-18T15:21:27.846Z" }, + { url = "https://files.pythonhosted.org/packages/2a/9f/8619835cd6a711d6272d62abb78c033bda638fdc54c4e7f4272cf1c0962b/MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:cdb82a876c47801bb54a690c5ae105a46b392ac6099881cdfb9f6e95e4014c6a", size = 24208, upload-time = "2024-10-18T15:21:28.744Z" }, + { url = "https://files.pythonhosted.org/packages/f9/bf/176950a1792b2cd2102b8ffeb5133e1ed984547b75db47c25a67d3359f77/MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:cabc348d87e913db6ab4aa100f01b08f481097838bdddf7c7a84b7575b7309ca", size = 23357, upload-time = "2024-10-18T15:21:29.545Z" }, + { url = "https://files.pythonhosted.org/packages/ce/4f/9a02c1d335caabe5c4efb90e1b6e8ee944aa245c1aaaab8e8a618987d816/MarkupSafe-3.0.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:444dcda765c8a838eaae23112db52f1efaf750daddb2d9ca300bcae1039adc5c", size = 23344, upload-time = "2024-10-18T15:21:30.366Z" }, + { url = "https://files.pythonhosted.org/packages/ee/55/c271b57db36f748f0e04a759ace9f8f759ccf22b4960c270c78a394f58be/MarkupSafe-3.0.2-cp313-cp313-win32.whl", hash = "sha256:bcf3e58998965654fdaff38e58584d8937aa3096ab5354d493c77d1fdd66d7a1", size = 15101, upload-time = "2024-10-18T15:21:31.207Z" }, + { url = "https://files.pythonhosted.org/packages/29/88/07df22d2dd4df40aba9f3e402e6dc1b8ee86297dddbad4872bd5e7b0094f/MarkupSafe-3.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:e6a2a455bd412959b57a172ce6328d2dd1f01cb2135efda2e4576e8a23fa3b0f", size = 15603, upload-time = "2024-10-18T15:21:32.032Z" }, + { url = "https://files.pythonhosted.org/packages/62/6a/8b89d24db2d32d433dffcd6a8779159da109842434f1dd2f6e71f32f738c/MarkupSafe-3.0.2-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:b5a6b3ada725cea8a5e634536b1b01c30bcdcd7f9c6fff4151548d5bf6b3a36c", size = 14510, upload-time = "2024-10-18T15:21:33.625Z" }, + { url = "https://files.pythonhosted.org/packages/7a/06/a10f955f70a2e5a9bf78d11a161029d278eeacbd35ef806c3fd17b13060d/MarkupSafe-3.0.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:a904af0a6162c73e3edcb969eeeb53a63ceeb5d8cf642fade7d39e7963a22ddb", size = 12486, upload-time = "2024-10-18T15:21:34.611Z" }, + { url = "https://files.pythonhosted.org/packages/34/cf/65d4a571869a1a9078198ca28f39fba5fbb910f952f9dbc5220afff9f5e6/MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4aa4e5faecf353ed117801a068ebab7b7e09ffb6e1d5e412dc852e0da018126c", size = 25480, upload-time = "2024-10-18T15:21:35.398Z" }, + { url = "https://files.pythonhosted.org/packages/0c/e3/90e9651924c430b885468b56b3d597cabf6d72be4b24a0acd1fa0e12af67/MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0ef13eaeee5b615fb07c9a7dadb38eac06a0608b41570d8ade51c56539e509d", size = 23914, upload-time = "2024-10-18T15:21:36.231Z" }, + { url = "https://files.pythonhosted.org/packages/66/8c/6c7cf61f95d63bb866db39085150df1f2a5bd3335298f14a66b48e92659c/MarkupSafe-3.0.2-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d16a81a06776313e817c951135cf7340a3e91e8c1ff2fac444cfd75fffa04afe", size = 23796, upload-time = "2024-10-18T15:21:37.073Z" }, + { url = "https://files.pythonhosted.org/packages/bb/35/cbe9238ec3f47ac9a7c8b3df7a808e7cb50fe149dc7039f5f454b3fba218/MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:6381026f158fdb7c72a168278597a5e3a5222e83ea18f543112b2662a9b699c5", size = 25473, upload-time = "2024-10-18T15:21:37.932Z" }, + { url = "https://files.pythonhosted.org/packages/e6/32/7621a4382488aa283cc05e8984a9c219abad3bca087be9ec77e89939ded9/MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:3d79d162e7be8f996986c064d1c7c817f6df3a77fe3d6859f6f9e7be4b8c213a", size = 24114, upload-time = "2024-10-18T15:21:39.799Z" }, + { url = "https://files.pythonhosted.org/packages/0d/80/0985960e4b89922cb5a0bac0ed39c5b96cbc1a536a99f30e8c220a996ed9/MarkupSafe-3.0.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:131a3c7689c85f5ad20f9f6fb1b866f402c445b220c19fe4308c0b147ccd2ad9", size = 24098, upload-time = "2024-10-18T15:21:40.813Z" }, + { url = "https://files.pythonhosted.org/packages/82/78/fedb03c7d5380df2427038ec8d973587e90561b2d90cd472ce9254cf348b/MarkupSafe-3.0.2-cp313-cp313t-win32.whl", hash = "sha256:ba8062ed2cf21c07a9e295d5b8a2a5ce678b913b45fdf68c32d95d6c1291e0b6", size = 15208, upload-time = "2024-10-18T15:21:41.814Z" }, + { url = "https://files.pythonhosted.org/packages/4f/65/6079a46068dfceaeabb5dcad6d674f5f5c61a6fa5673746f42a9f4c233b3/MarkupSafe-3.0.2-cp313-cp313t-win_amd64.whl", hash = "sha256:e444a31f8db13eb18ada366ab3cf45fd4b31e4db1236a4448f68778c1d1a5a2f", size = 15739, upload-time = "2024-10-18T15:21:42.784Z" }, +] + +[[package]] +name = "mpmath" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e0/47/dd32fa426cc72114383ac549964eecb20ecfd886d1e5ccf5340b55b02f57/mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f", size = 508106, upload-time = "2023-03-07T16:47:11.061Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198, upload-time = "2023-03-07T16:47:09.197Z" }, +] + +[[package]] +name = "msgpack" +version = "1.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/45/b1/ea4f68038a18c77c9467400d166d74c4ffa536f34761f7983a104357e614/msgpack-1.1.1.tar.gz", hash = "sha256:77b79ce34a2bdab2594f490c8e80dd62a02d650b91a75159a63ec413b8d104cd", size = 173555, upload-time = "2025-06-13T06:52:51.324Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a1/38/561f01cf3577430b59b340b51329803d3a5bf6a45864a55f4ef308ac11e3/msgpack-1.1.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:3765afa6bd4832fc11c3749be4ba4b69a0e8d7b728f78e68120a157a4c5d41f0", size = 81677, upload-time = "2025-06-13T06:52:16.64Z" }, + { url = "https://files.pythonhosted.org/packages/09/48/54a89579ea36b6ae0ee001cba8c61f776451fad3c9306cd80f5b5c55be87/msgpack-1.1.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:8ddb2bcfd1a8b9e431c8d6f4f7db0773084e107730ecf3472f1dfe9ad583f3d9", size = 78603, upload-time = "2025-06-13T06:52:17.843Z" }, + { url = "https://files.pythonhosted.org/packages/a0/60/daba2699b308e95ae792cdc2ef092a38eb5ee422f9d2fbd4101526d8a210/msgpack-1.1.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:196a736f0526a03653d829d7d4c5500a97eea3648aebfd4b6743875f28aa2af8", size = 420504, upload-time = "2025-06-13T06:52:18.982Z" }, + { url = "https://files.pythonhosted.org/packages/20/22/2ebae7ae43cd8f2debc35c631172ddf14e2a87ffcc04cf43ff9df9fff0d3/msgpack-1.1.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d592d06e3cc2f537ceeeb23d38799c6ad83255289bb84c2e5792e5a8dea268a", size = 423749, upload-time = "2025-06-13T06:52:20.211Z" }, + { url = "https://files.pythonhosted.org/packages/40/1b/54c08dd5452427e1179a40b4b607e37e2664bca1c790c60c442c8e972e47/msgpack-1.1.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4df2311b0ce24f06ba253fda361f938dfecd7b961576f9be3f3fbd60e87130ac", size = 404458, upload-time = "2025-06-13T06:52:21.429Z" }, + { url = "https://files.pythonhosted.org/packages/2e/60/6bb17e9ffb080616a51f09928fdd5cac1353c9becc6c4a8abd4e57269a16/msgpack-1.1.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:e4141c5a32b5e37905b5940aacbc59739f036930367d7acce7a64e4dec1f5e0b", size = 405976, upload-time = "2025-06-13T06:52:22.995Z" }, + { url = "https://files.pythonhosted.org/packages/ee/97/88983e266572e8707c1f4b99c8fd04f9eb97b43f2db40e3172d87d8642db/msgpack-1.1.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b1ce7f41670c5a69e1389420436f41385b1aa2504c3b0c30620764b15dded2e7", size = 408607, upload-time = "2025-06-13T06:52:24.152Z" }, + { url = "https://files.pythonhosted.org/packages/bc/66/36c78af2efaffcc15a5a61ae0df53a1d025f2680122e2a9eb8442fed3ae4/msgpack-1.1.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4147151acabb9caed4e474c3344181e91ff7a388b888f1e19ea04f7e73dc7ad5", size = 424172, upload-time = "2025-06-13T06:52:25.704Z" }, + { url = "https://files.pythonhosted.org/packages/8c/87/a75eb622b555708fe0427fab96056d39d4c9892b0c784b3a721088c7ee37/msgpack-1.1.1-cp313-cp313-win32.whl", hash = "sha256:500e85823a27d6d9bba1d057c871b4210c1dd6fb01fbb764e37e4e8847376323", size = 65347, upload-time = "2025-06-13T06:52:26.846Z" }, + { url = "https://files.pythonhosted.org/packages/ca/91/7dc28d5e2a11a5ad804cf2b7f7a5fcb1eb5a4966d66a5d2b41aee6376543/msgpack-1.1.1-cp313-cp313-win_amd64.whl", hash = "sha256:6d489fba546295983abd142812bda76b57e33d0b9f5d5b71c09a583285506f69", size = 72341, upload-time = "2025-06-13T06:52:27.835Z" }, +] + +[[package]] +name = "multidict" +version = "6.4.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/91/2f/a3470242707058fe856fe59241eee5635d79087100b7042a867368863a27/multidict-6.4.4.tar.gz", hash = "sha256:69ee9e6ba214b5245031b76233dd95408a0fd57fdb019ddcc1ead4790932a8e8", size = 90183, upload-time = "2025-05-19T14:16:37.381Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/df/2a/e166d2ffbf4b10131b2d5b0e458f7cee7d986661caceae0de8753042d4b2/multidict-6.4.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:82ffabefc8d84c2742ad19c37f02cde5ec2a1ee172d19944d380f920a340e4b9", size = 64123, upload-time = "2025-05-19T14:15:11.044Z" }, + { url = "https://files.pythonhosted.org/packages/8c/96/e200e379ae5b6f95cbae472e0199ea98913f03d8c9a709f42612a432932c/multidict-6.4.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:6a2f58a66fe2c22615ad26156354005391e26a2f3721c3621504cd87c1ea87bf", size = 38049, upload-time = "2025-05-19T14:15:12.902Z" }, + { url = "https://files.pythonhosted.org/packages/75/fb/47afd17b83f6a8c7fa863c6d23ac5ba6a0e6145ed8a6bcc8da20b2b2c1d2/multidict-6.4.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5883d6ee0fd9d8a48e9174df47540b7545909841ac82354c7ae4cbe9952603bd", size = 37078, upload-time = "2025-05-19T14:15:14.282Z" }, + { url = "https://files.pythonhosted.org/packages/fa/70/1af3143000eddfb19fd5ca5e78393985ed988ac493bb859800fe0914041f/multidict-6.4.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9abcf56a9511653fa1d052bfc55fbe53dbee8f34e68bd6a5a038731b0ca42d15", size = 224097, upload-time = "2025-05-19T14:15:15.566Z" }, + { url = "https://files.pythonhosted.org/packages/b1/39/d570c62b53d4fba844e0378ffbcd02ac25ca423d3235047013ba2f6f60f8/multidict-6.4.4-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:6ed5ae5605d4ad5a049fad2a28bb7193400700ce2f4ae484ab702d1e3749c3f9", size = 230768, upload-time = "2025-05-19T14:15:17.308Z" }, + { url = "https://files.pythonhosted.org/packages/fd/f8/ed88f2c4d06f752b015933055eb291d9bc184936903752c66f68fb3c95a7/multidict-6.4.4-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bbfcb60396f9bcfa63e017a180c3105b8c123a63e9d1428a36544e7d37ca9e20", size = 231331, upload-time = "2025-05-19T14:15:18.73Z" }, + { url = "https://files.pythonhosted.org/packages/9c/6f/8e07cffa32f483ab887b0d56bbd8747ac2c1acd00dc0af6fcf265f4a121e/multidict-6.4.4-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b0f1987787f5f1e2076b59692352ab29a955b09ccc433c1f6b8e8e18666f608b", size = 230169, upload-time = "2025-05-19T14:15:20.179Z" }, + { url = "https://files.pythonhosted.org/packages/e6/2b/5dcf173be15e42f330110875a2668ddfc208afc4229097312212dc9c1236/multidict-6.4.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d0121ccce8c812047d8d43d691a1ad7641f72c4f730474878a5aeae1b8ead8c", size = 222947, upload-time = "2025-05-19T14:15:21.714Z" }, + { url = "https://files.pythonhosted.org/packages/39/75/4ddcbcebe5ebcd6faa770b629260d15840a5fc07ce8ad295a32e14993726/multidict-6.4.4-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:83ec4967114295b8afd120a8eec579920c882831a3e4c3331d591a8e5bfbbc0f", size = 215761, upload-time = "2025-05-19T14:15:23.242Z" }, + { url = "https://files.pythonhosted.org/packages/6a/c9/55e998ae45ff15c5608e384206aa71a11e1b7f48b64d166db400b14a3433/multidict-6.4.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:995f985e2e268deaf17867801b859a282e0448633f1310e3704b30616d269d69", size = 227605, upload-time = "2025-05-19T14:15:24.763Z" }, + { url = "https://files.pythonhosted.org/packages/04/49/c2404eac74497503c77071bd2e6f88c7e94092b8a07601536b8dbe99be50/multidict-6.4.4-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:d832c608f94b9f92a0ec8b7e949be7792a642b6e535fcf32f3e28fab69eeb046", size = 226144, upload-time = "2025-05-19T14:15:26.249Z" }, + { url = "https://files.pythonhosted.org/packages/62/c5/0cd0c3c6f18864c40846aa2252cd69d308699cb163e1c0d989ca301684da/multidict-6.4.4-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:d21c1212171cf7da703c5b0b7a0e85be23b720818aef502ad187d627316d5645", size = 221100, upload-time = "2025-05-19T14:15:28.303Z" }, + { url = "https://files.pythonhosted.org/packages/71/7b/f2f3887bea71739a046d601ef10e689528d4f911d84da873b6be9194ffea/multidict-6.4.4-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:cbebaa076aaecad3d4bb4c008ecc73b09274c952cf6a1b78ccfd689e51f5a5b0", size = 232731, upload-time = "2025-05-19T14:15:30.263Z" }, + { url = "https://files.pythonhosted.org/packages/e5/b3/d9de808349df97fa75ec1372758701b5800ebad3c46ae377ad63058fbcc6/multidict-6.4.4-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:c93a6fb06cc8e5d3628b2b5fda215a5db01e8f08fc15fadd65662d9b857acbe4", size = 229637, upload-time = "2025-05-19T14:15:33.337Z" }, + { url = "https://files.pythonhosted.org/packages/5e/57/13207c16b615eb4f1745b44806a96026ef8e1b694008a58226c2d8f5f0a5/multidict-6.4.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:8cd8f81f1310182362fb0c7898145ea9c9b08a71081c5963b40ee3e3cac589b1", size = 225594, upload-time = "2025-05-19T14:15:34.832Z" }, + { url = "https://files.pythonhosted.org/packages/3a/e4/d23bec2f70221604f5565000632c305fc8f25ba953e8ce2d8a18842b9841/multidict-6.4.4-cp313-cp313-win32.whl", hash = "sha256:3e9f1cd61a0ab857154205fb0b1f3d3ace88d27ebd1409ab7af5096e409614cd", size = 35359, upload-time = "2025-05-19T14:15:36.246Z" }, + { url = "https://files.pythonhosted.org/packages/a7/7a/cfe1a47632be861b627f46f642c1d031704cc1c0f5c0efbde2ad44aa34bd/multidict-6.4.4-cp313-cp313-win_amd64.whl", hash = "sha256:8ffb40b74400e4455785c2fa37eba434269149ec525fc8329858c862e4b35373", size = 38903, upload-time = "2025-05-19T14:15:37.507Z" }, + { url = "https://files.pythonhosted.org/packages/68/7b/15c259b0ab49938a0a1c8f3188572802704a779ddb294edc1b2a72252e7c/multidict-6.4.4-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:6a602151dbf177be2450ef38966f4be3467d41a86c6a845070d12e17c858a156", size = 68895, upload-time = "2025-05-19T14:15:38.856Z" }, + { url = "https://files.pythonhosted.org/packages/f1/7d/168b5b822bccd88142e0a3ce985858fea612404edd228698f5af691020c9/multidict-6.4.4-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0d2b9712211b860d123815a80b859075d86a4d54787e247d7fbee9db6832cf1c", size = 40183, upload-time = "2025-05-19T14:15:40.197Z" }, + { url = "https://files.pythonhosted.org/packages/e0/b7/d4b8d98eb850ef28a4922ba508c31d90715fd9b9da3801a30cea2967130b/multidict-6.4.4-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:d2fa86af59f8fc1972e121ade052145f6da22758f6996a197d69bb52f8204e7e", size = 39592, upload-time = "2025-05-19T14:15:41.508Z" }, + { url = "https://files.pythonhosted.org/packages/18/28/a554678898a19583548e742080cf55d169733baf57efc48c2f0273a08583/multidict-6.4.4-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50855d03e9e4d66eab6947ba688ffb714616f985838077bc4b490e769e48da51", size = 226071, upload-time = "2025-05-19T14:15:42.877Z" }, + { url = "https://files.pythonhosted.org/packages/ee/dc/7ba6c789d05c310e294f85329efac1bf5b450338d2542498db1491a264df/multidict-6.4.4-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:5bce06b83be23225be1905dcdb6b789064fae92499fbc458f59a8c0e68718601", size = 222597, upload-time = "2025-05-19T14:15:44.412Z" }, + { url = "https://files.pythonhosted.org/packages/24/4f/34eadbbf401b03768dba439be0fb94b0d187facae9142821a3d5599ccb3b/multidict-6.4.4-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:66ed0731f8e5dfd8369a883b6e564aca085fb9289aacabd9decd70568b9a30de", size = 228253, upload-time = "2025-05-19T14:15:46.474Z" }, + { url = "https://files.pythonhosted.org/packages/c0/e6/493225a3cdb0d8d80d43a94503fc313536a07dae54a3f030d279e629a2bc/multidict-6.4.4-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:329ae97fc2f56f44d91bc47fe0972b1f52d21c4b7a2ac97040da02577e2daca2", size = 226146, upload-time = "2025-05-19T14:15:48.003Z" }, + { url = "https://files.pythonhosted.org/packages/2f/70/e411a7254dc3bff6f7e6e004303b1b0591358e9f0b7c08639941e0de8bd6/multidict-6.4.4-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c27e5dcf520923d6474d98b96749e6805f7677e93aaaf62656005b8643f907ab", size = 220585, upload-time = "2025-05-19T14:15:49.546Z" }, + { url = "https://files.pythonhosted.org/packages/08/8f/beb3ae7406a619100d2b1fb0022c3bb55a8225ab53c5663648ba50dfcd56/multidict-6.4.4-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:058cc59b9e9b143cc56715e59e22941a5d868c322242278d28123a5d09cdf6b0", size = 212080, upload-time = "2025-05-19T14:15:51.151Z" }, + { url = "https://files.pythonhosted.org/packages/9c/ec/355124e9d3d01cf8edb072fd14947220f357e1c5bc79c88dff89297e9342/multidict-6.4.4-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:69133376bc9a03f8c47343d33f91f74a99c339e8b58cea90433d8e24bb298031", size = 226558, upload-time = "2025-05-19T14:15:52.665Z" }, + { url = "https://files.pythonhosted.org/packages/fd/22/d2b95cbebbc2ada3be3812ea9287dcc9712d7f1a012fad041770afddb2ad/multidict-6.4.4-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:d6b15c55721b1b115c5ba178c77104123745b1417527ad9641a4c5e2047450f0", size = 212168, upload-time = "2025-05-19T14:15:55.279Z" }, + { url = "https://files.pythonhosted.org/packages/4d/c5/62bfc0b2f9ce88326dbe7179f9824a939c6c7775b23b95de777267b9725c/multidict-6.4.4-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:a887b77f51d3d41e6e1a63cf3bc7ddf24de5939d9ff69441387dfefa58ac2e26", size = 217970, upload-time = "2025-05-19T14:15:56.806Z" }, + { url = "https://files.pythonhosted.org/packages/79/74/977cea1aadc43ff1c75d23bd5bc4768a8fac98c14e5878d6ee8d6bab743c/multidict-6.4.4-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:632a3bf8f1787f7ef7d3c2f68a7bde5be2f702906f8b5842ad6da9d974d0aab3", size = 226980, upload-time = "2025-05-19T14:15:58.313Z" }, + { url = "https://files.pythonhosted.org/packages/48/fc/cc4a1a2049df2eb84006607dc428ff237af38e0fcecfdb8a29ca47b1566c/multidict-6.4.4-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:a145c550900deb7540973c5cdb183b0d24bed6b80bf7bddf33ed8f569082535e", size = 220641, upload-time = "2025-05-19T14:15:59.866Z" }, + { url = "https://files.pythonhosted.org/packages/3b/6a/a7444d113ab918701988d4abdde373dbdfd2def7bd647207e2bf645c7eac/multidict-6.4.4-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:cc5d83c6619ca5c9672cb78b39ed8542f1975a803dee2cda114ff73cbb076edd", size = 221728, upload-time = "2025-05-19T14:16:01.535Z" }, + { url = "https://files.pythonhosted.org/packages/2b/b0/fdf4c73ad1c55e0f4dbbf2aa59dd37037334091f9a4961646d2b7ac91a86/multidict-6.4.4-cp313-cp313t-win32.whl", hash = "sha256:3312f63261b9df49be9d57aaa6abf53a6ad96d93b24f9cc16cf979956355ce6e", size = 41913, upload-time = "2025-05-19T14:16:03.199Z" }, + { url = "https://files.pythonhosted.org/packages/8e/92/27989ecca97e542c0d01d05a98a5ae12198a243a9ee12563a0313291511f/multidict-6.4.4-cp313-cp313t-win_amd64.whl", hash = "sha256:ba852168d814b2c73333073e1c7116d9395bea69575a01b0b3c89d2d5a87c8fb", size = 46112, upload-time = "2025-05-19T14:16:04.909Z" }, + { url = "https://files.pythonhosted.org/packages/84/5d/e17845bb0fa76334477d5de38654d27946d5b5d3695443987a094a71b440/multidict-6.4.4-py3-none-any.whl", hash = "sha256:bd4557071b561a8b3b6075c3ce93cf9bfb6182cb241805c3d66ced3b75eff4ac", size = 10481, upload-time = "2025-05-19T14:16:36.024Z" }, +] + +[[package]] +name = "multiprocess" +version = "0.70.16" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "dill" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b5/ae/04f39c5d0d0def03247c2893d6f2b83c136bf3320a2154d7b8858f2ba72d/multiprocess-0.70.16.tar.gz", hash = "sha256:161af703d4652a0e1410be6abccecde4a7ddffd19341be0a7011b94aeb171ac1", size = 1772603, upload-time = "2024-01-28T18:52:34.85Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bc/f7/7ec7fddc92e50714ea3745631f79bd9c96424cb2702632521028e57d3a36/multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02", size = 134824, upload-time = "2024-01-28T18:52:26.062Z" }, + { url = "https://files.pythonhosted.org/packages/50/15/b56e50e8debaf439f44befec5b2af11db85f6e0f344c3113ae0be0593a91/multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a", size = 143519, upload-time = "2024-01-28T18:52:28.115Z" }, + { url = "https://files.pythonhosted.org/packages/0a/7d/a988f258104dcd2ccf1ed40fdc97e26c4ac351eeaf81d76e266c52d84e2f/multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e", size = 146741, upload-time = "2024-01-28T18:52:29.395Z" }, + { url = "https://files.pythonhosted.org/packages/ea/89/38df130f2c799090c978b366cfdf5b96d08de5b29a4a293df7f7429fa50b/multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435", size = 132628, upload-time = "2024-01-28T18:52:30.853Z" }, + { url = "https://files.pythonhosted.org/packages/da/d9/f7f9379981e39b8c2511c9e0326d212accacb82f12fbfdc1aa2ce2a7b2b6/multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3", size = 133351, upload-time = "2024-01-28T18:52:31.981Z" }, +] + +[[package]] +name = "networkx" +version = "3.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6c/4f/ccdb8ad3a38e583f214547fd2f7ff1fc160c43a75af88e6aec213404b96a/networkx-3.5.tar.gz", hash = "sha256:d4c6f9cf81f52d69230866796b82afbccdec3db7ae4fbd1b65ea750feed50037", size = 2471065, upload-time = "2025-05-29T11:35:07.804Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl", hash = "sha256:0030d386a9a06dee3565298b4a734b68589749a544acbb6c412dc9e2489ec6ec", size = 2034406, upload-time = "2025-05-29T11:35:04.961Z" }, +] + +[[package]] +name = "ninja" +version = "1.11.1.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/95/d4/6b0324541018561c5e73e617bd16f20a4fc17d1179bb3b3520b6ca8beb7b/ninja-1.11.1.4.tar.gz", hash = "sha256:6aa39f6e894e0452e5b297327db00019383ae55d5d9c57c73b04f13bf79d438a", size = 201256, upload-time = "2025-03-22T06:46:43.46Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4f/b1/3a61b348936b62a386465b1937cd778fa3a5748582e26d832dbab844ff27/ninja-1.11.1.4-py3-none-macosx_10_9_universal2.whl", hash = "sha256:b33923c8da88e8da20b6053e38deb433f53656441614207e01d283ad02c5e8e7", size = 279071, upload-time = "2025-03-22T06:46:17.806Z" }, + { url = "https://files.pythonhosted.org/packages/12/42/4c94fdad51fcf1f039a156e97de9e4d564c2a8cc0303782d36f9bd893a4b/ninja-1.11.1.4-py3-none-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:cede0af00b58e27b31f2482ba83292a8e9171cdb9acc2c867a3b6e40b3353e43", size = 472026, upload-time = "2025-03-22T06:46:19.974Z" }, + { url = "https://files.pythonhosted.org/packages/eb/7a/455d2877fe6cf99886849c7f9755d897df32eaf3a0fba47b56e615f880f7/ninja-1.11.1.4-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:096487995473320de7f65d622c3f1d16c3ad174797602218ca8c967f51ec38a0", size = 422814, upload-time = "2025-03-22T06:46:21.235Z" }, + { url = "https://files.pythonhosted.org/packages/e3/ad/fb6cca942528e25e8e0ab0f0cf98fe007319bf05cf69d726c564b815c4af/ninja-1.11.1.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3090d4488fadf6047d0d7a1db0c9643a8d391f0d94729554dbb89b5bdc769d7", size = 156965, upload-time = "2025-03-22T06:46:23.45Z" }, + { url = "https://files.pythonhosted.org/packages/a8/e7/d94a1b60031b115dd88526834b3da69eaacdc3c1a6769773ca8e2b1386b5/ninja-1.11.1.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ecce44a00325a93631792974659cf253a815cc6da4ec96f89742925dfc295a0d", size = 179937, upload-time = "2025-03-22T06:46:24.728Z" }, + { url = "https://files.pythonhosted.org/packages/08/cc/e9316a28235409e9363794fc3d0b3083e48dd80d441006de66421e55f364/ninja-1.11.1.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9c29bb66d2aa46a2409ab369ea804c730faec7652e8c22c1e428cc09216543e5", size = 157020, upload-time = "2025-03-22T06:46:26.046Z" }, + { url = "https://files.pythonhosted.org/packages/e3/30/389b22300541aa5f2e9dad322c4de2f84be4e32aa4e8babd9160d620b5f1/ninja-1.11.1.4-py3-none-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:055f386fb550c2c9d6157e45e20a84d29c47968876b9c5794ae2aec46f952306", size = 130389, upload-time = "2025-03-22T06:46:27.174Z" }, + { url = "https://files.pythonhosted.org/packages/a9/10/e27f35cb92813aabbb7ae771b1685b45be1cc8a0798ce7d4bfd08d142b93/ninja-1.11.1.4-py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:f6186d7607bb090c3be1e10c8a56b690be238f953616626f5032238c66e56867", size = 372435, upload-time = "2025-03-22T06:46:28.637Z" }, + { url = "https://files.pythonhosted.org/packages/c2/26/e3559619756739aae124c6abf7fe41f7e546ab1209cfbffb13137bff2d2e/ninja-1.11.1.4-py3-none-musllinux_1_1_i686.whl", hash = "sha256:cf4453679d15babc04ba023d68d091bb613091b67101c88f85d2171c6621c6eb", size = 419300, upload-time = "2025-03-22T06:46:30.392Z" }, + { url = "https://files.pythonhosted.org/packages/35/46/809e4e9572570991b8e6f88f3583807d017371ab4cb09171cbc72a7eb3e4/ninja-1.11.1.4-py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:d4a6f159b08b0ac4aca5ee1572e3e402f969139e71d85d37c0e2872129098749", size = 420239, upload-time = "2025-03-22T06:46:32.442Z" }, + { url = "https://files.pythonhosted.org/packages/e6/64/5cb5710d15f844edf02ada577f8eddfdcd116f47eec15850f3371a3a4b33/ninja-1.11.1.4-py3-none-musllinux_1_1_s390x.whl", hash = "sha256:c3b96bd875f3ef1db782470e9e41d7508905a0986571f219d20ffed238befa15", size = 415986, upload-time = "2025-03-22T06:46:33.821Z" }, + { url = "https://files.pythonhosted.org/packages/95/b2/0e9ab1d926f423b12b09925f78afcc5e48b3c22e7121be3ddf6c35bf06a3/ninja-1.11.1.4-py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:cf554e73f72c04deb04d0cf51f5fdb1903d9c9ca3d2344249c8ce3bd616ebc02", size = 379657, upload-time = "2025-03-22T06:46:36.166Z" }, + { url = "https://files.pythonhosted.org/packages/c8/3e/fd6d330d0434168e7fe070d414b57dd99c4c133faa69c05b42a3cbdc6c13/ninja-1.11.1.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:cfdd09776436a1ff3c4a2558d3fc50a689fb9d7f1bdbc3e6f7b8c2991341ddb3", size = 454466, upload-time = "2025-03-22T06:46:37.413Z" }, + { url = "https://files.pythonhosted.org/packages/e6/df/a25f3ad0b1c59d1b90564096e4fd89a6ca30d562b1e942f23880c3000b89/ninja-1.11.1.4-py3-none-win32.whl", hash = "sha256:2ab67a41c90bea5ec4b795bab084bc0b3b3bb69d3cd21ca0294fc0fc15a111eb", size = 255931, upload-time = "2025-03-22T06:46:39.171Z" }, + { url = "https://files.pythonhosted.org/packages/5b/10/9b8fe9ac004847490cc7b54896124c01ce2d87d95dc60aabd0b8591addff/ninja-1.11.1.4-py3-none-win_amd64.whl", hash = "sha256:4617b3c12ff64b611a7d93fd9e378275512bb36eff8babff7c83f5116b4f8d66", size = 296461, upload-time = "2025-03-22T06:46:40.532Z" }, + { url = "https://files.pythonhosted.org/packages/b9/58/612a17593c2d117f96c7f6b7f1e6570246bddc4b1e808519403a1417f217/ninja-1.11.1.4-py3-none-win_arm64.whl", hash = "sha256:5713cf50c5be50084a8693308a63ecf9e55c3132a78a41ab1363a28b6caaaee1", size = 271441, upload-time = "2025-03-22T06:46:42.147Z" }, +] + +[[package]] +name = "numpy" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f3/db/8e12381333aea300890829a0a36bfa738cac95475d88982d538725143fd9/numpy-2.3.0.tar.gz", hash = "sha256:581f87f9e9e9db2cba2141400e160e9dd644ee248788d6f90636eeb8fd9260a6", size = 20382813, upload-time = "2025-06-07T14:54:32.608Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/73/fc/1d67f751fd4dbafc5780244fe699bc4084268bad44b7c5deb0492473127b/numpy-2.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5754ab5595bfa2c2387d241296e0381c21f44a4b90a776c3c1d39eede13a746a", size = 20889633, upload-time = "2025-06-07T14:44:06.839Z" }, + { url = "https://files.pythonhosted.org/packages/e8/95/73ffdb69e5c3f19ec4530f8924c4386e7ba097efc94b9c0aff607178ad94/numpy-2.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:d11fa02f77752d8099573d64e5fe33de3229b6632036ec08f7080f46b6649959", size = 14151683, upload-time = "2025-06-07T14:44:28.847Z" }, + { url = "https://files.pythonhosted.org/packages/64/d5/06d4bb31bb65a1d9c419eb5676173a2f90fd8da3c59f816cc54c640ce265/numpy-2.3.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:aba48d17e87688a765ab1cd557882052f238e2f36545dfa8e29e6a91aef77afe", size = 5102683, upload-time = "2025-06-07T14:44:38.417Z" }, + { url = "https://files.pythonhosted.org/packages/12/8b/6c2cef44f8ccdc231f6b56013dff1d71138c48124334aded36b1a1b30c5a/numpy-2.3.0-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:4dc58865623023b63b10d52f18abaac3729346a7a46a778381e0e3af4b7f3beb", size = 6640253, upload-time = "2025-06-07T14:44:49.359Z" }, + { url = "https://files.pythonhosted.org/packages/62/aa/fca4bf8de3396ddb59544df9b75ffe5b73096174de97a9492d426f5cd4aa/numpy-2.3.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:df470d376f54e052c76517393fa443758fefcdd634645bc9c1f84eafc67087f0", size = 14258658, upload-time = "2025-06-07T14:45:10.156Z" }, + { url = "https://files.pythonhosted.org/packages/1c/12/734dce1087eed1875f2297f687e671cfe53a091b6f2f55f0c7241aad041b/numpy-2.3.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:87717eb24d4a8a64683b7a4e91ace04e2f5c7c77872f823f02a94feee186168f", size = 16628765, upload-time = "2025-06-07T14:45:35.076Z" }, + { url = "https://files.pythonhosted.org/packages/48/03/ffa41ade0e825cbcd5606a5669962419528212a16082763fc051a7247d76/numpy-2.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:d8fa264d56882b59dcb5ea4d6ab6f31d0c58a57b41aec605848b6eb2ef4a43e8", size = 15564335, upload-time = "2025-06-07T14:45:58.797Z" }, + { url = "https://files.pythonhosted.org/packages/07/58/869398a11863310aee0ff85a3e13b4c12f20d032b90c4b3ee93c3b728393/numpy-2.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e651756066a0eaf900916497e20e02fe1ae544187cb0fe88de981671ee7f6270", size = 18360608, upload-time = "2025-06-07T14:46:25.687Z" }, + { url = "https://files.pythonhosted.org/packages/2f/8a/5756935752ad278c17e8a061eb2127c9a3edf4ba2c31779548b336f23c8d/numpy-2.3.0-cp313-cp313-win32.whl", hash = "sha256:e43c3cce3b6ae5f94696669ff2a6eafd9a6b9332008bafa4117af70f4b88be6f", size = 6310005, upload-time = "2025-06-07T14:50:13.138Z" }, + { url = "https://files.pythonhosted.org/packages/08/60/61d60cf0dfc0bf15381eaef46366ebc0c1a787856d1db0c80b006092af84/numpy-2.3.0-cp313-cp313-win_amd64.whl", hash = "sha256:81ae0bf2564cf475f94be4a27ef7bcf8af0c3e28da46770fc904da9abd5279b5", size = 12729093, upload-time = "2025-06-07T14:50:31.82Z" }, + { url = "https://files.pythonhosted.org/packages/66/31/2f2f2d2b3e3c32d5753d01437240feaa32220b73258c9eef2e42a0832866/numpy-2.3.0-cp313-cp313-win_arm64.whl", hash = "sha256:c8738baa52505fa6e82778580b23f945e3578412554d937093eac9205e845e6e", size = 9885689, upload-time = "2025-06-07T14:50:47.888Z" }, + { url = "https://files.pythonhosted.org/packages/f1/89/c7828f23cc50f607ceb912774bb4cff225ccae7131c431398ad8400e2c98/numpy-2.3.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:39b27d8b38942a647f048b675f134dd5a567f95bfff481f9109ec308515c51d8", size = 20986612, upload-time = "2025-06-07T14:46:56.077Z" }, + { url = "https://files.pythonhosted.org/packages/dd/46/79ecf47da34c4c50eedec7511e53d57ffdfd31c742c00be7dc1d5ffdb917/numpy-2.3.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:0eba4a1ea88f9a6f30f56fdafdeb8da3774349eacddab9581a21234b8535d3d3", size = 14298953, upload-time = "2025-06-07T14:47:18.053Z" }, + { url = "https://files.pythonhosted.org/packages/59/44/f6caf50713d6ff4480640bccb2a534ce1d8e6e0960c8f864947439f0ee95/numpy-2.3.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:b0f1f11d0a1da54927436505a5a7670b154eac27f5672afc389661013dfe3d4f", size = 5225806, upload-time = "2025-06-07T14:47:27.524Z" }, + { url = "https://files.pythonhosted.org/packages/a6/43/e1fd1aca7c97e234dd05e66de4ab7a5be54548257efcdd1bc33637e72102/numpy-2.3.0-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:690d0a5b60a47e1f9dcec7b77750a4854c0d690e9058b7bef3106e3ae9117808", size = 6735169, upload-time = "2025-06-07T14:47:38.057Z" }, + { url = "https://files.pythonhosted.org/packages/84/89/f76f93b06a03177c0faa7ca94d0856c4e5c4bcaf3c5f77640c9ed0303e1c/numpy-2.3.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:8b51ead2b258284458e570942137155978583e407babc22e3d0ed7af33ce06f8", size = 14330701, upload-time = "2025-06-07T14:47:59.113Z" }, + { url = "https://files.pythonhosted.org/packages/aa/f5/4858c3e9ff7a7d64561b20580cf7cc5d085794bd465a19604945d6501f6c/numpy-2.3.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:aaf81c7b82c73bd9b45e79cfb9476cb9c29e937494bfe9092c26aece812818ad", size = 16692983, upload-time = "2025-06-07T14:48:24.196Z" }, + { url = "https://files.pythonhosted.org/packages/08/17/0e3b4182e691a10e9483bcc62b4bb8693dbf9ea5dc9ba0b77a60435074bb/numpy-2.3.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:f420033a20b4f6a2a11f585f93c843ac40686a7c3fa514060a97d9de93e5e72b", size = 15641435, upload-time = "2025-06-07T14:48:47.712Z" }, + { url = "https://files.pythonhosted.org/packages/4e/d5/463279fda028d3c1efa74e7e8d507605ae87f33dbd0543cf4c4527c8b882/numpy-2.3.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:d344ca32ab482bcf8735d8f95091ad081f97120546f3d250240868430ce52555", size = 18433798, upload-time = "2025-06-07T14:49:14.866Z" }, + { url = "https://files.pythonhosted.org/packages/0e/1e/7a9d98c886d4c39a2b4d3a7c026bffcf8fbcaf518782132d12a301cfc47a/numpy-2.3.0-cp313-cp313t-win32.whl", hash = "sha256:48a2e8eaf76364c32a1feaa60d6925eaf32ed7a040183b807e02674305beef61", size = 6438632, upload-time = "2025-06-07T14:49:25.67Z" }, + { url = "https://files.pythonhosted.org/packages/fe/ab/66fc909931d5eb230107d016861824f335ae2c0533f422e654e5ff556784/numpy-2.3.0-cp313-cp313t-win_amd64.whl", hash = "sha256:ba17f93a94e503551f154de210e4d50c5e3ee20f7e7a1b5f6ce3f22d419b93bb", size = 12868491, upload-time = "2025-06-07T14:49:44.898Z" }, + { url = "https://files.pythonhosted.org/packages/ee/e8/2c8a1c9e34d6f6d600c83d5ce5b71646c32a13f34ca5c518cc060639841c/numpy-2.3.0-cp313-cp313t-win_arm64.whl", hash = "sha256:f14e016d9409680959691c109be98c436c6249eaf7f118b424679793607b5944", size = 9935345, upload-time = "2025-06-07T14:50:02.311Z" }, +] + +[[package]] +name = "nvidia-cublas-cu12" +version = "12.6.4.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/af/eb/ff4b8c503fa1f1796679dce648854d58751982426e4e4b37d6fce49d259c/nvidia_cublas_cu12-12.6.4.1-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:08ed2686e9875d01b58e3cb379c6896df8e76c75e0d4a7f7dace3d7b6d9ef8eb", size = 393138322, upload-time = "2024-11-20T17:40:25.65Z" }, +] + +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.6.80" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/49/60/7b6497946d74bcf1de852a21824d63baad12cd417db4195fc1bfe59db953/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6768bad6cab4f19e8292125e5f1ac8aa7d1718704012a0e3272a6f61c4bce132", size = 8917980, upload-time = "2024-11-20T17:36:04.019Z" }, + { url = "https://files.pythonhosted.org/packages/a5/24/120ee57b218d9952c379d1e026c4479c9ece9997a4fb46303611ee48f038/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a3eff6cdfcc6a4c35db968a06fcadb061cbc7d6dde548609a941ff8701b98b73", size = 8917972, upload-time = "2024-10-01T16:58:06.036Z" }, +] + +[[package]] +name = "nvidia-cuda-nvrtc-cu12" +version = "12.6.77" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/75/2e/46030320b5a80661e88039f59060d1790298b4718944a65a7f2aeda3d9e9/nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:35b0cc6ee3a9636d5409133e79273ce1f3fd087abb0532d2d2e8fff1fe9efc53", size = 23650380, upload-time = "2024-10-01T17:00:14.643Z" }, +] + +[[package]] +name = "nvidia-cuda-runtime-cu12" +version = "12.6.77" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e1/23/e717c5ac26d26cf39a27fbc076240fad2e3b817e5889d671b67f4f9f49c5/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ba3b56a4f896141e25e19ab287cd71e52a6a0f4b29d0d31609f60e3b4d5219b7", size = 897690, upload-time = "2024-11-20T17:35:30.697Z" }, + { url = "https://files.pythonhosted.org/packages/f0/62/65c05e161eeddbafeca24dc461f47de550d9fa8a7e04eb213e32b55cfd99/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a84d15d5e1da416dd4774cb42edf5e954a3e60cc945698dc1d5be02321c44dc8", size = 897678, upload-time = "2024-10-01T16:57:33.821Z" }, +] + +[[package]] +name = "nvidia-cudnn-cu12" +version = "9.5.1.17" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas-cu12" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/78/4535c9c7f859a64781e43c969a3a7e84c54634e319a996d43ef32ce46f83/nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:30ac3869f6db17d170e0e556dd6cc5eee02647abc31ca856634d5a40f82c15b2", size = 570988386, upload-time = "2024-10-25T19:54:26.39Z" }, +] + +[[package]] +name = "nvidia-cufft-cu12" +version = "11.3.0.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/8f/16/73727675941ab8e6ffd86ca3a4b7b47065edcca7a997920b831f8147c99d/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ccba62eb9cef5559abd5e0d54ceed2d9934030f51163df018532142a8ec533e5", size = 200221632, upload-time = "2024-11-20T17:41:32.357Z" }, + { url = "https://files.pythonhosted.org/packages/60/de/99ec247a07ea40c969d904fc14f3a356b3e2a704121675b75c366b694ee1/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.whl", hash = "sha256:768160ac89f6f7b459bee747e8d175dbf53619cfe74b2a5636264163138013ca", size = 200221622, upload-time = "2024-10-01T17:03:58.79Z" }, +] + +[[package]] +name = "nvidia-cufile-cu12" +version = "1.11.1.6" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b2/66/cc9876340ac68ae71b15c743ddb13f8b30d5244af344ec8322b449e35426/nvidia_cufile_cu12-1.11.1.6-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cc23469d1c7e52ce6c1d55253273d32c565dd22068647f3aa59b3c6b005bf159", size = 1142103, upload-time = "2024-11-20T17:42:11.83Z" }, +] + +[[package]] +name = "nvidia-curand-cu12" +version = "10.3.7.77" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/73/1b/44a01c4e70933637c93e6e1a8063d1e998b50213a6b65ac5a9169c47e98e/nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a42cd1344297f70b9e39a1e4f467a4e1c10f1da54ff7a85c12197f6c652c8bdf", size = 56279010, upload-time = "2024-11-20T17:42:50.958Z" }, + { url = "https://files.pythonhosted.org/packages/4a/aa/2c7ff0b5ee02eaef890c0ce7d4f74bc30901871c5e45dee1ae6d0083cd80/nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:99f1a32f1ac2bd134897fc7a203f779303261268a65762a623bf30cc9fe79117", size = 56279000, upload-time = "2024-10-01T17:04:45.274Z" }, +] + +[[package]] +name = "nvidia-cusolver-cu12" +version = "11.7.1.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas-cu12" }, + { name = "nvidia-cusparse-cu12" }, + { name = "nvidia-nvjitlink-cu12" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/f0/6e/c2cf12c9ff8b872e92b4a5740701e51ff17689c4d726fca91875b07f655d/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e9e49843a7707e42022babb9bcfa33c29857a93b88020c4e4434656a655b698c", size = 158229790, upload-time = "2024-11-20T17:43:43.211Z" }, + { url = "https://files.pythonhosted.org/packages/9f/81/baba53585da791d043c10084cf9553e074548408e04ae884cfe9193bd484/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:6cf28f17f64107a0c4d7802be5ff5537b2130bfc112f25d5a30df227058ca0e6", size = 158229780, upload-time = "2024-10-01T17:05:39.875Z" }, +] + +[[package]] +name = "nvidia-cusparse-cu12" +version = "12.5.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/06/1e/b8b7c2f4099a37b96af5c9bb158632ea9e5d9d27d7391d7eb8fc45236674/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7556d9eca156e18184b94947ade0fba5bb47d69cec46bf8660fd2c71a4b48b73", size = 216561367, upload-time = "2024-11-20T17:44:54.824Z" }, + { url = "https://files.pythonhosted.org/packages/43/ac/64c4316ba163e8217a99680c7605f779accffc6a4bcd0c778c12948d3707/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:23749a6571191a215cb74d1cdbff4a86e7b19f1200c071b3fcf844a5bea23a2f", size = 216561357, upload-time = "2024-10-01T17:06:29.861Z" }, +] + +[[package]] +name = "nvidia-cusparselt-cu12" +version = "0.6.3" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3b/9a/72ef35b399b0e183bc2e8f6f558036922d453c4d8237dab26c666a04244b/nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e5c8a26c36445dd2e6812f1177978a24e2d37cacce7e090f297a688d1ec44f46", size = 156785796, upload-time = "2024-10-15T21:29:17.709Z" }, +] + +[[package]] +name = "nvidia-nccl-cu12" +version = "2.26.2" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/67/ca/f42388aed0fddd64ade7493dbba36e1f534d4e6fdbdd355c6a90030ae028/nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:694cf3879a206553cc9d7dbda76b13efaf610fdb70a50cba303de1b0d1530ac6", size = 201319755, upload-time = "2025-03-13T00:29:55.296Z" }, +] + +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.6.85" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9d/d7/c5383e47c7e9bf1c99d5bd2a8c935af2b6d705ad831a7ec5c97db4d82f4f/nvidia_nvjitlink_cu12-12.6.85-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:eedc36df9e88b682efe4309aa16b5b4e78c2407eac59e8c10a6a47535164369a", size = 19744971, upload-time = "2024-11-20T17:46:53.366Z" }, +] + +[[package]] +name = "nvidia-nvtx-cu12" +version = "12.6.77" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/56/9a/fff8376f8e3d084cd1530e1ef7b879bb7d6d265620c95c1b322725c694f4/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b90bed3df379fa79afbd21be8e04a0314336b8ae16768b58f2d34cb1d04cd7d2", size = 89276, upload-time = "2024-11-20T17:38:27.621Z" }, + { url = "https://files.pythonhosted.org/packages/9e/4e/0d0c945463719429b7bd21dece907ad0bde437a2ff12b9b12fee94722ab0/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl", hash = "sha256:6574241a3ec5fdc9334353ab8c479fe75841dbe8f4532a8fc97ce63503330ba1", size = 89265, upload-time = "2024-10-01T17:00:38.172Z" }, +] + +[[package]] +name = "packaging" +version = "25.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" }, +] + +[[package]] +name = "pandas" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "python-dateutil" }, + { name = "pytz" }, + { name = "tzdata" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/51/48f713c4c728d7c55ef7444ba5ea027c26998d96d1a40953b346438602fc/pandas-2.3.0.tar.gz", hash = "sha256:34600ab34ebf1131a7613a260a61dbe8b62c188ec0ea4c296da7c9a06b004133", size = 4484490, upload-time = "2025-06-05T03:27:54.133Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d3/57/5cb75a56a4842bbd0511c3d1c79186d8315b82dac802118322b2de1194fe/pandas-2.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:2c7e2fc25f89a49a11599ec1e76821322439d90820108309bf42130d2f36c983", size = 11518913, upload-time = "2025-06-05T03:27:02.757Z" }, + { url = "https://files.pythonhosted.org/packages/05/01/0c8785610e465e4948a01a059562176e4c8088aa257e2e074db868f86d4e/pandas-2.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c6da97aeb6a6d233fb6b17986234cc723b396b50a3c6804776351994f2a658fd", size = 10655249, upload-time = "2025-06-05T16:50:20.17Z" }, + { url = "https://files.pythonhosted.org/packages/e8/6a/47fd7517cd8abe72a58706aab2b99e9438360d36dcdb052cf917b7bf3bdc/pandas-2.3.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb32dc743b52467d488e7a7c8039b821da2826a9ba4f85b89ea95274f863280f", size = 11328359, upload-time = "2025-06-05T03:27:06.431Z" }, + { url = "https://files.pythonhosted.org/packages/2a/b3/463bfe819ed60fb7e7ddffb4ae2ee04b887b3444feee6c19437b8f834837/pandas-2.3.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:213cd63c43263dbb522c1f8a7c9d072e25900f6975596f883f4bebd77295d4f3", size = 12024789, upload-time = "2025-06-05T03:27:09.875Z" }, + { url = "https://files.pythonhosted.org/packages/04/0c/e0704ccdb0ac40aeb3434d1c641c43d05f75c92e67525df39575ace35468/pandas-2.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:1d2b33e68d0ce64e26a4acc2e72d747292084f4e8db4c847c6f5f6cbe56ed6d8", size = 12480734, upload-time = "2025-06-06T00:00:22.246Z" }, + { url = "https://files.pythonhosted.org/packages/e9/df/815d6583967001153bb27f5cf075653d69d51ad887ebbf4cfe1173a1ac58/pandas-2.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:430a63bae10b5086995db1b02694996336e5a8ac9a96b4200572b413dfdfccb9", size = 13223381, upload-time = "2025-06-05T03:27:15.641Z" }, + { url = "https://files.pythonhosted.org/packages/79/88/ca5973ed07b7f484c493e941dbff990861ca55291ff7ac67c815ce347395/pandas-2.3.0-cp313-cp313-win_amd64.whl", hash = "sha256:4930255e28ff5545e2ca404637bcc56f031893142773b3468dc021c6c32a1390", size = 10970135, upload-time = "2025-06-05T03:27:24.131Z" }, + { url = "https://files.pythonhosted.org/packages/24/fb/0994c14d1f7909ce83f0b1fb27958135513c4f3f2528bde216180aa73bfc/pandas-2.3.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:f925f1ef673b4bd0271b1809b72b3270384f2b7d9d14a189b12b7fc02574d575", size = 12141356, upload-time = "2025-06-05T03:27:34.547Z" }, + { url = "https://files.pythonhosted.org/packages/9d/a2/9b903e5962134497ac4f8a96f862ee3081cb2506f69f8e4778ce3d9c9d82/pandas-2.3.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e78ad363ddb873a631e92a3c063ade1ecfb34cae71e9a2be6ad100f875ac1042", size = 11474674, upload-time = "2025-06-05T03:27:39.448Z" }, + { url = "https://files.pythonhosted.org/packages/81/3a/3806d041bce032f8de44380f866059437fb79e36d6b22c82c187e65f765b/pandas-2.3.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:951805d146922aed8357e4cc5671b8b0b9be1027f0619cea132a9f3f65f2f09c", size = 11439876, upload-time = "2025-06-05T03:27:43.652Z" }, + { url = "https://files.pythonhosted.org/packages/15/aa/3fc3181d12b95da71f5c2537c3e3b3af6ab3a8c392ab41ebb766e0929bc6/pandas-2.3.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a881bc1309f3fce34696d07b00f13335c41f5f5a8770a33b09ebe23261cfc67", size = 11966182, upload-time = "2025-06-05T03:27:47.652Z" }, + { url = "https://files.pythonhosted.org/packages/37/e7/e12f2d9b0a2c4a2cc86e2aabff7ccfd24f03e597d770abfa2acd313ee46b/pandas-2.3.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:e1991bbb96f4050b09b5f811253c4f3cf05ee89a589379aa36cd623f21a31d6f", size = 12547686, upload-time = "2025-06-06T00:00:26.142Z" }, + { url = "https://files.pythonhosted.org/packages/39/c2/646d2e93e0af70f4e5359d870a63584dacbc324b54d73e6b3267920ff117/pandas-2.3.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:bb3be958022198531eb7ec2008cfc78c5b1eed51af8600c6c5d9160d89d8d249", size = 13231847, upload-time = "2025-06-05T03:27:51.465Z" }, +] + +[[package]] +name = "peft" +version = "0.15.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "accelerate" }, + { name = "huggingface-hub" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "psutil" }, + { name = "pyyaml" }, + { name = "safetensors" }, + { name = "torch" }, + { name = "tqdm" }, + { name = "transformers" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/33/65/faa18cd8ffbe0f742c3f2559770646cce2574b9cd28a2a05e8d36f64e968/peft-0.15.2.tar.gz", hash = "sha256:7059029f4d42a092ded1aa117dd366a46084aef638bdd593f6ab0195d5427fcd", size = 472952, upload-time = "2025-04-15T15:27:53.09Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/68/85/8e6ea3d1089f2b6de3c1cd34bbbd7560912af9d34b057be3b8b8fefe1da3/peft-0.15.2-py3-none-any.whl", hash = "sha256:0dfc942b03b7af4b7267cd4e30b15e3a4a1d277adc581ce6245fc13f1f93d0a0", size = 411051, upload-time = "2025-04-15T15:27:50.799Z" }, +] + +[[package]] +name = "propcache" +version = "0.3.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a6/16/43264e4a779dd8588c21a70f0709665ee8f611211bdd2c87d952cfa7c776/propcache-0.3.2.tar.gz", hash = "sha256:20d7d62e4e7ef05f221e0db2856b979540686342e7dd9973b815599c7057e168", size = 44139, upload-time = "2025-06-09T22:56:06.081Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/d1/8c747fafa558c603c4ca19d8e20b288aa0c7cda74e9402f50f31eb65267e/propcache-0.3.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ca592ed634a73ca002967458187109265e980422116c0a107cf93d81f95af945", size = 71286, upload-time = "2025-06-09T22:54:54.369Z" }, + { url = "https://files.pythonhosted.org/packages/61/99/d606cb7986b60d89c36de8a85d58764323b3a5ff07770a99d8e993b3fa73/propcache-0.3.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:9ecb0aad4020e275652ba3975740f241bd12a61f1a784df044cf7477a02bc252", size = 42425, upload-time = "2025-06-09T22:54:55.642Z" }, + { url = "https://files.pythonhosted.org/packages/8c/96/ef98f91bbb42b79e9bb82bdd348b255eb9d65f14dbbe3b1594644c4073f7/propcache-0.3.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7f08f1cc28bd2eade7a8a3d2954ccc673bb02062e3e7da09bc75d843386b342f", size = 41846, upload-time = "2025-06-09T22:54:57.246Z" }, + { url = "https://files.pythonhosted.org/packages/5b/ad/3f0f9a705fb630d175146cd7b1d2bf5555c9beaed54e94132b21aac098a6/propcache-0.3.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d1a342c834734edb4be5ecb1e9fb48cb64b1e2320fccbd8c54bf8da8f2a84c33", size = 208871, upload-time = "2025-06-09T22:54:58.975Z" }, + { url = "https://files.pythonhosted.org/packages/3a/38/2085cda93d2c8b6ec3e92af2c89489a36a5886b712a34ab25de9fbca7992/propcache-0.3.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8a544caaae1ac73f1fecfae70ded3e93728831affebd017d53449e3ac052ac1e", size = 215720, upload-time = "2025-06-09T22:55:00.471Z" }, + { url = "https://files.pythonhosted.org/packages/61/c1/d72ea2dc83ac7f2c8e182786ab0fc2c7bd123a1ff9b7975bee671866fe5f/propcache-0.3.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:310d11aa44635298397db47a3ebce7db99a4cc4b9bbdfcf6c98a60c8d5261cf1", size = 215203, upload-time = "2025-06-09T22:55:01.834Z" }, + { url = "https://files.pythonhosted.org/packages/af/81/b324c44ae60c56ef12007105f1460d5c304b0626ab0cc6b07c8f2a9aa0b8/propcache-0.3.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c1396592321ac83157ac03a2023aa6cc4a3cc3cfdecb71090054c09e5a7cce3", size = 206365, upload-time = "2025-06-09T22:55:03.199Z" }, + { url = "https://files.pythonhosted.org/packages/09/73/88549128bb89e66d2aff242488f62869014ae092db63ccea53c1cc75a81d/propcache-0.3.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8cabf5b5902272565e78197edb682017d21cf3b550ba0460ee473753f28d23c1", size = 196016, upload-time = "2025-06-09T22:55:04.518Z" }, + { url = "https://files.pythonhosted.org/packages/b9/3f/3bdd14e737d145114a5eb83cb172903afba7242f67c5877f9909a20d948d/propcache-0.3.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:0a2f2235ac46a7aa25bdeb03a9e7060f6ecbd213b1f9101c43b3090ffb971ef6", size = 205596, upload-time = "2025-06-09T22:55:05.942Z" }, + { url = "https://files.pythonhosted.org/packages/0f/ca/2f4aa819c357d3107c3763d7ef42c03980f9ed5c48c82e01e25945d437c1/propcache-0.3.2-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:92b69e12e34869a6970fd2f3da91669899994b47c98f5d430b781c26f1d9f387", size = 200977, upload-time = "2025-06-09T22:55:07.792Z" }, + { url = "https://files.pythonhosted.org/packages/cd/4a/e65276c7477533c59085251ae88505caf6831c0e85ff8b2e31ebcbb949b1/propcache-0.3.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:54e02207c79968ebbdffc169591009f4474dde3b4679e16634d34c9363ff56b4", size = 197220, upload-time = "2025-06-09T22:55:09.173Z" }, + { url = "https://files.pythonhosted.org/packages/7c/54/fc7152e517cf5578278b242396ce4d4b36795423988ef39bb8cd5bf274c8/propcache-0.3.2-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:4adfb44cb588001f68c5466579d3f1157ca07f7504fc91ec87862e2b8e556b88", size = 210642, upload-time = "2025-06-09T22:55:10.62Z" }, + { url = "https://files.pythonhosted.org/packages/b9/80/abeb4a896d2767bf5f1ea7b92eb7be6a5330645bd7fb844049c0e4045d9d/propcache-0.3.2-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:fd3e6019dc1261cd0291ee8919dd91fbab7b169bb76aeef6c716833a3f65d206", size = 212789, upload-time = "2025-06-09T22:55:12.029Z" }, + { url = "https://files.pythonhosted.org/packages/b3/db/ea12a49aa7b2b6d68a5da8293dcf50068d48d088100ac016ad92a6a780e6/propcache-0.3.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4c181cad81158d71c41a2bce88edce078458e2dd5ffee7eddd6b05da85079f43", size = 205880, upload-time = "2025-06-09T22:55:13.45Z" }, + { url = "https://files.pythonhosted.org/packages/d1/e5/9076a0bbbfb65d1198007059c65639dfd56266cf8e477a9707e4b1999ff4/propcache-0.3.2-cp313-cp313-win32.whl", hash = "sha256:8a08154613f2249519e549de2330cf8e2071c2887309a7b07fb56098f5170a02", size = 37220, upload-time = "2025-06-09T22:55:15.284Z" }, + { url = "https://files.pythonhosted.org/packages/d3/f5/b369e026b09a26cd77aa88d8fffd69141d2ae00a2abaaf5380d2603f4b7f/propcache-0.3.2-cp313-cp313-win_amd64.whl", hash = "sha256:e41671f1594fc4ab0a6dec1351864713cb3a279910ae8b58f884a88a0a632c05", size = 40678, upload-time = "2025-06-09T22:55:16.445Z" }, + { url = "https://files.pythonhosted.org/packages/a4/3a/6ece377b55544941a08d03581c7bc400a3c8cd3c2865900a68d5de79e21f/propcache-0.3.2-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:9a3cf035bbaf035f109987d9d55dc90e4b0e36e04bbbb95af3055ef17194057b", size = 76560, upload-time = "2025-06-09T22:55:17.598Z" }, + { url = "https://files.pythonhosted.org/packages/0c/da/64a2bb16418740fa634b0e9c3d29edff1db07f56d3546ca2d86ddf0305e1/propcache-0.3.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:156c03d07dc1323d8dacaa221fbe028c5c70d16709cdd63502778e6c3ccca1b0", size = 44676, upload-time = "2025-06-09T22:55:18.922Z" }, + { url = "https://files.pythonhosted.org/packages/36/7b/f025e06ea51cb72c52fb87e9b395cced02786610b60a3ed51da8af017170/propcache-0.3.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:74413c0ba02ba86f55cf60d18daab219f7e531620c15f1e23d95563f505efe7e", size = 44701, upload-time = "2025-06-09T22:55:20.106Z" }, + { url = "https://files.pythonhosted.org/packages/a4/00/faa1b1b7c3b74fc277f8642f32a4c72ba1d7b2de36d7cdfb676db7f4303e/propcache-0.3.2-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f066b437bb3fa39c58ff97ab2ca351db465157d68ed0440abecb21715eb24b28", size = 276934, upload-time = "2025-06-09T22:55:21.5Z" }, + { url = "https://files.pythonhosted.org/packages/74/ab/935beb6f1756e0476a4d5938ff44bf0d13a055fed880caf93859b4f1baf4/propcache-0.3.2-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f1304b085c83067914721e7e9d9917d41ad87696bf70f0bc7dee450e9c71ad0a", size = 278316, upload-time = "2025-06-09T22:55:22.918Z" }, + { url = "https://files.pythonhosted.org/packages/f8/9d/994a5c1ce4389610838d1caec74bdf0e98b306c70314d46dbe4fcf21a3e2/propcache-0.3.2-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ab50cef01b372763a13333b4e54021bdcb291fc9a8e2ccb9c2df98be51bcde6c", size = 282619, upload-time = "2025-06-09T22:55:24.651Z" }, + { url = "https://files.pythonhosted.org/packages/2b/00/a10afce3d1ed0287cef2e09506d3be9822513f2c1e96457ee369adb9a6cd/propcache-0.3.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fad3b2a085ec259ad2c2842666b2a0a49dea8463579c606426128925af1ed725", size = 265896, upload-time = "2025-06-09T22:55:26.049Z" }, + { url = "https://files.pythonhosted.org/packages/2e/a8/2aa6716ffa566ca57c749edb909ad27884680887d68517e4be41b02299f3/propcache-0.3.2-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:261fa020c1c14deafd54c76b014956e2f86991af198c51139faf41c4d5e83892", size = 252111, upload-time = "2025-06-09T22:55:27.381Z" }, + { url = "https://files.pythonhosted.org/packages/36/4f/345ca9183b85ac29c8694b0941f7484bf419c7f0fea2d1e386b4f7893eed/propcache-0.3.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:46d7f8aa79c927e5f987ee3a80205c987717d3659f035c85cf0c3680526bdb44", size = 268334, upload-time = "2025-06-09T22:55:28.747Z" }, + { url = "https://files.pythonhosted.org/packages/3e/ca/fcd54f78b59e3f97b3b9715501e3147f5340167733d27db423aa321e7148/propcache-0.3.2-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:6d8f3f0eebf73e3c0ff0e7853f68be638b4043c65a70517bb575eff54edd8dbe", size = 255026, upload-time = "2025-06-09T22:55:30.184Z" }, + { url = "https://files.pythonhosted.org/packages/8b/95/8e6a6bbbd78ac89c30c225210a5c687790e532ba4088afb8c0445b77ef37/propcache-0.3.2-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:03c89c1b14a5452cf15403e291c0ccd7751d5b9736ecb2c5bab977ad6c5bcd81", size = 250724, upload-time = "2025-06-09T22:55:31.646Z" }, + { url = "https://files.pythonhosted.org/packages/ee/b0/0dd03616142baba28e8b2d14ce5df6631b4673850a3d4f9c0f9dd714a404/propcache-0.3.2-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:0cc17efde71e12bbaad086d679ce575268d70bc123a5a71ea7ad76f70ba30bba", size = 268868, upload-time = "2025-06-09T22:55:33.209Z" }, + { url = "https://files.pythonhosted.org/packages/c5/98/2c12407a7e4fbacd94ddd32f3b1e3d5231e77c30ef7162b12a60e2dd5ce3/propcache-0.3.2-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:acdf05d00696bc0447e278bb53cb04ca72354e562cf88ea6f9107df8e7fd9770", size = 271322, upload-time = "2025-06-09T22:55:35.065Z" }, + { url = "https://files.pythonhosted.org/packages/35/91/9cb56efbb428b006bb85db28591e40b7736847b8331d43fe335acf95f6c8/propcache-0.3.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:4445542398bd0b5d32df908031cb1b30d43ac848e20470a878b770ec2dcc6330", size = 265778, upload-time = "2025-06-09T22:55:36.45Z" }, + { url = "https://files.pythonhosted.org/packages/9a/4c/b0fe775a2bdd01e176b14b574be679d84fc83958335790f7c9a686c1f468/propcache-0.3.2-cp313-cp313t-win32.whl", hash = "sha256:f86e5d7cd03afb3a1db8e9f9f6eff15794e79e791350ac48a8c924e6f439f394", size = 41175, upload-time = "2025-06-09T22:55:38.436Z" }, + { url = "https://files.pythonhosted.org/packages/a4/ff/47f08595e3d9b5e149c150f88d9714574f1a7cbd89fe2817158a952674bf/propcache-0.3.2-cp313-cp313t-win_amd64.whl", hash = "sha256:9704bedf6e7cbe3c65eca4379a9b53ee6a83749f047808cbb5044d40d7d72198", size = 44857, upload-time = "2025-06-09T22:55:39.687Z" }, + { url = "https://files.pythonhosted.org/packages/cc/35/cc0aaecf278bb4575b8555f2b137de5ab821595ddae9da9d3cd1da4072c7/propcache-0.3.2-py3-none-any.whl", hash = "sha256:98f1ec44fb675f5052cccc8e609c46ed23a35a1cfd18545ad4e29002d858a43f", size = 12663, upload-time = "2025-06-09T22:56:04.484Z" }, +] + +[[package]] +name = "psutil" +version = "7.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2a/80/336820c1ad9286a4ded7e845b2eccfcb27851ab8ac6abece774a6ff4d3de/psutil-7.0.0.tar.gz", hash = "sha256:7be9c3eba38beccb6495ea33afd982a44074b78f28c434a1f51cc07fd315c456", size = 497003, upload-time = "2025-02-13T21:54:07.946Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ed/e6/2d26234410f8b8abdbf891c9da62bee396583f713fb9f3325a4760875d22/psutil-7.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:101d71dc322e3cffd7cea0650b09b3d08b8e7c4109dd6809fe452dfd00e58b25", size = 238051, upload-time = "2025-02-13T21:54:12.36Z" }, + { url = "https://files.pythonhosted.org/packages/04/8b/30f930733afe425e3cbfc0e1468a30a18942350c1a8816acfade80c005c4/psutil-7.0.0-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:39db632f6bb862eeccf56660871433e111b6ea58f2caea825571951d4b6aa3da", size = 239535, upload-time = "2025-02-13T21:54:16.07Z" }, + { url = "https://files.pythonhosted.org/packages/2a/ed/d362e84620dd22876b55389248e522338ed1bf134a5edd3b8231d7207f6d/psutil-7.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fcee592b4c6f146991ca55919ea3d1f8926497a713ed7faaf8225e174581e91", size = 275004, upload-time = "2025-02-13T21:54:18.662Z" }, + { url = "https://files.pythonhosted.org/packages/bf/b9/b0eb3f3cbcb734d930fdf839431606844a825b23eaf9a6ab371edac8162c/psutil-7.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b1388a4f6875d7e2aff5c4ca1cc16c545ed41dd8bb596cefea80111db353a34", size = 277986, upload-time = "2025-02-13T21:54:21.811Z" }, + { url = "https://files.pythonhosted.org/packages/eb/a2/709e0fe2f093556c17fbafda93ac032257242cabcc7ff3369e2cb76a97aa/psutil-7.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5f098451abc2828f7dc6b58d44b532b22f2088f4999a937557b603ce72b1993", size = 279544, upload-time = "2025-02-13T21:54:24.68Z" }, + { url = "https://files.pythonhosted.org/packages/50/e6/eecf58810b9d12e6427369784efe814a1eec0f492084ce8eb8f4d89d6d61/psutil-7.0.0-cp37-abi3-win32.whl", hash = "sha256:ba3fcef7523064a6c9da440fc4d6bd07da93ac726b5733c29027d7dc95b39d99", size = 241053, upload-time = "2025-02-13T21:54:34.31Z" }, + { url = "https://files.pythonhosted.org/packages/50/1b/6921afe68c74868b4c9fa424dad3be35b095e16687989ebbb50ce4fceb7c/psutil-7.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:4cf3d4eb1aa9b348dec30105c55cd9b7d4629285735a102beb4441e38db90553", size = 244885, upload-time = "2025-02-13T21:54:37.486Z" }, +] + +[[package]] +name = "py-cpuinfo" +version = "9.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/37/a8/d832f7293ebb21690860d2e01d8115e5ff6f2ae8bbdc953f0eb0fa4bd2c7/py-cpuinfo-9.0.0.tar.gz", hash = "sha256:3cdbbf3fac90dc6f118bfd64384f309edeadd902d7c8fb17f02ffa1fc3f49690", size = 104716, upload-time = "2022-10-25T20:38:06.303Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/a9/023730ba63db1e494a271cb018dcd361bd2c917ba7004c3e49d5daf795a2/py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5", size = 22335, upload-time = "2022-10-25T20:38:27.636Z" }, +] + +[[package]] +name = "pyarrow" +version = "20.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/ee/a7810cb9f3d6e9238e61d312076a9859bf3668fd21c69744de9532383912/pyarrow-20.0.0.tar.gz", hash = "sha256:febc4a913592573c8d5805091a6c2b5064c8bd6e002131f01061797d91c783c1", size = 1125187, upload-time = "2025-04-27T12:34:23.264Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9b/aa/daa413b81446d20d4dad2944110dcf4cf4f4179ef7f685dd5a6d7570dc8e/pyarrow-20.0.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:a15532e77b94c61efadde86d10957950392999503b3616b2ffcef7621a002893", size = 30798501, upload-time = "2025-04-27T12:30:48.351Z" }, + { url = "https://files.pythonhosted.org/packages/ff/75/2303d1caa410925de902d32ac215dc80a7ce7dd8dfe95358c165f2adf107/pyarrow-20.0.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:dd43f58037443af715f34f1322c782ec463a3c8a94a85fdb2d987ceb5658e061", size = 32277895, upload-time = "2025-04-27T12:30:55.238Z" }, + { url = "https://files.pythonhosted.org/packages/92/41/fe18c7c0b38b20811b73d1bdd54b1fccba0dab0e51d2048878042d84afa8/pyarrow-20.0.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aa0d288143a8585806e3cc7c39566407aab646fb9ece164609dac1cfff45f6ae", size = 41327322, upload-time = "2025-04-27T12:31:05.587Z" }, + { url = "https://files.pythonhosted.org/packages/da/ab/7dbf3d11db67c72dbf36ae63dcbc9f30b866c153b3a22ef728523943eee6/pyarrow-20.0.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b6953f0114f8d6f3d905d98e987d0924dabce59c3cda380bdfaa25a6201563b4", size = 42411441, upload-time = "2025-04-27T12:31:15.675Z" }, + { url = "https://files.pythonhosted.org/packages/90/c3/0c7da7b6dac863af75b64e2f827e4742161128c350bfe7955b426484e226/pyarrow-20.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:991f85b48a8a5e839b2128590ce07611fae48a904cae6cab1f089c5955b57eb5", size = 40677027, upload-time = "2025-04-27T12:31:24.631Z" }, + { url = "https://files.pythonhosted.org/packages/be/27/43a47fa0ff9053ab5203bb3faeec435d43c0d8bfa40179bfd076cdbd4e1c/pyarrow-20.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:97c8dc984ed09cb07d618d57d8d4b67a5100a30c3818c2fb0b04599f0da2de7b", size = 42281473, upload-time = "2025-04-27T12:31:31.311Z" }, + { url = "https://files.pythonhosted.org/packages/bc/0b/d56c63b078876da81bbb9ba695a596eabee9b085555ed12bf6eb3b7cab0e/pyarrow-20.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9b71daf534f4745818f96c214dbc1e6124d7daf059167330b610fc69b6f3d3e3", size = 42893897, upload-time = "2025-04-27T12:31:39.406Z" }, + { url = "https://files.pythonhosted.org/packages/92/ac/7d4bd020ba9145f354012838692d48300c1b8fe5634bfda886abcada67ed/pyarrow-20.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e8b88758f9303fa5a83d6c90e176714b2fd3852e776fc2d7e42a22dd6c2fb368", size = 44543847, upload-time = "2025-04-27T12:31:45.997Z" }, + { url = "https://files.pythonhosted.org/packages/9d/07/290f4abf9ca702c5df7b47739c1b2c83588641ddfa2cc75e34a301d42e55/pyarrow-20.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:30b3051b7975801c1e1d387e17c588d8ab05ced9b1e14eec57915f79869b5031", size = 25653219, upload-time = "2025-04-27T12:31:54.11Z" }, + { url = "https://files.pythonhosted.org/packages/95/df/720bb17704b10bd69dde086e1400b8eefb8f58df3f8ac9cff6c425bf57f1/pyarrow-20.0.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:ca151afa4f9b7bc45bcc791eb9a89e90a9eb2772767d0b1e5389609c7d03db63", size = 30853957, upload-time = "2025-04-27T12:31:59.215Z" }, + { url = "https://files.pythonhosted.org/packages/d9/72/0d5f875efc31baef742ba55a00a25213a19ea64d7176e0fe001c5d8b6e9a/pyarrow-20.0.0-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:4680f01ecd86e0dd63e39eb5cd59ef9ff24a9d166db328679e36c108dc993d4c", size = 32247972, upload-time = "2025-04-27T12:32:05.369Z" }, + { url = "https://files.pythonhosted.org/packages/d5/bc/e48b4fa544d2eea72f7844180eb77f83f2030b84c8dad860f199f94307ed/pyarrow-20.0.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7f4c8534e2ff059765647aa69b75d6543f9fef59e2cd4c6d18015192565d2b70", size = 41256434, upload-time = "2025-04-27T12:32:11.814Z" }, + { url = "https://files.pythonhosted.org/packages/c3/01/974043a29874aa2cf4f87fb07fd108828fc7362300265a2a64a94965e35b/pyarrow-20.0.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e1f8a47f4b4ae4c69c4d702cfbdfe4d41e18e5c7ef6f1bb1c50918c1e81c57b", size = 42353648, upload-time = "2025-04-27T12:32:20.766Z" }, + { url = "https://files.pythonhosted.org/packages/68/95/cc0d3634cde9ca69b0e51cbe830d8915ea32dda2157560dda27ff3b3337b/pyarrow-20.0.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:a1f60dc14658efaa927f8214734f6a01a806d7690be4b3232ba526836d216122", size = 40619853, upload-time = "2025-04-27T12:32:28.1Z" }, + { url = "https://files.pythonhosted.org/packages/29/c2/3ad40e07e96a3e74e7ed7cc8285aadfa84eb848a798c98ec0ad009eb6bcc/pyarrow-20.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:204a846dca751428991346976b914d6d2a82ae5b8316a6ed99789ebf976551e6", size = 42241743, upload-time = "2025-04-27T12:32:35.792Z" }, + { url = "https://files.pythonhosted.org/packages/eb/cb/65fa110b483339add6a9bc7b6373614166b14e20375d4daa73483755f830/pyarrow-20.0.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:f3b117b922af5e4c6b9a9115825726cac7d8b1421c37c2b5e24fbacc8930612c", size = 42839441, upload-time = "2025-04-27T12:32:46.64Z" }, + { url = "https://files.pythonhosted.org/packages/98/7b/f30b1954589243207d7a0fbc9997401044bf9a033eec78f6cb50da3f304a/pyarrow-20.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:e724a3fd23ae5b9c010e7be857f4405ed5e679db5c93e66204db1a69f733936a", size = 44503279, upload-time = "2025-04-27T12:32:56.503Z" }, + { url = "https://files.pythonhosted.org/packages/37/40/ad395740cd641869a13bcf60851296c89624662575621968dcfafabaa7f6/pyarrow-20.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:82f1ee5133bd8f49d31be1299dc07f585136679666b502540db854968576faf9", size = 25944982, upload-time = "2025-04-27T12:33:04.72Z" }, +] + +[[package]] +name = "pydantic" +version = "2.11.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-types" }, + { name = "pydantic-core" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/00/dd/4325abf92c39ba8623b5af936ddb36ffcfe0beae70405d456ab1fb2f5b8c/pydantic-2.11.7.tar.gz", hash = "sha256:d989c3c6cb79469287b1569f7447a17848c998458d49ebe294e975b9baf0f0db", size = 788350, upload-time = "2025-06-14T08:33:17.137Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6a/c0/ec2b1c8712ca690e5d61979dee872603e92b8a32f94cc1b72d53beab008a/pydantic-2.11.7-py3-none-any.whl", hash = "sha256:dde5df002701f6de26248661f6835bbe296a47bf73990135c7d07ce741b9623b", size = 444782, upload-time = "2025-06-14T08:33:14.905Z" }, +] + +[[package]] +name = "pydantic-core" +version = "2.33.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ad/88/5f2260bdfae97aabf98f1778d43f69574390ad787afb646292a638c923d4/pydantic_core-2.33.2.tar.gz", hash = "sha256:7cb8bc3605c29176e1b105350d2e6474142d7c1bd1d9327c4a9bdb46bf827acc", size = 435195, upload-time = "2025-04-23T18:33:52.104Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/46/8c/99040727b41f56616573a28771b1bfa08a3d3fe74d3d513f01251f79f172/pydantic_core-2.33.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:1082dd3e2d7109ad8b7da48e1d4710c8d06c253cbc4a27c1cff4fbcaa97a9e3f", size = 2015688, upload-time = "2025-04-23T18:31:53.175Z" }, + { url = "https://files.pythonhosted.org/packages/3a/cc/5999d1eb705a6cefc31f0b4a90e9f7fc400539b1a1030529700cc1b51838/pydantic_core-2.33.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f517ca031dfc037a9c07e748cefd8d96235088b83b4f4ba8939105d20fa1dcd6", size = 1844808, upload-time = "2025-04-23T18:31:54.79Z" }, + { url = "https://files.pythonhosted.org/packages/6f/5e/a0a7b8885c98889a18b6e376f344da1ef323d270b44edf8174d6bce4d622/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a9f2c9dd19656823cb8250b0724ee9c60a82f3cdf68a080979d13092a3b0fef", size = 1885580, upload-time = "2025-04-23T18:31:57.393Z" }, + { url = "https://files.pythonhosted.org/packages/3b/2a/953581f343c7d11a304581156618c3f592435523dd9d79865903272c256a/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2b0a451c263b01acebe51895bfb0e1cc842a5c666efe06cdf13846c7418caa9a", size = 1973859, upload-time = "2025-04-23T18:31:59.065Z" }, + { url = "https://files.pythonhosted.org/packages/e6/55/f1a813904771c03a3f97f676c62cca0c0a4138654107c1b61f19c644868b/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ea40a64d23faa25e62a70ad163571c0b342b8bf66d5fa612ac0dec4f069d916", size = 2120810, upload-time = "2025-04-23T18:32:00.78Z" }, + { url = "https://files.pythonhosted.org/packages/aa/c3/053389835a996e18853ba107a63caae0b9deb4a276c6b472931ea9ae6e48/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0fb2d542b4d66f9470e8065c5469ec676978d625a8b7a363f07d9a501a9cb36a", size = 2676498, upload-time = "2025-04-23T18:32:02.418Z" }, + { url = "https://files.pythonhosted.org/packages/eb/3c/f4abd740877a35abade05e437245b192f9d0ffb48bbbbd708df33d3cda37/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fdac5d6ffa1b5a83bca06ffe7583f5576555e6c8b3a91fbd25ea7780f825f7d", size = 2000611, upload-time = "2025-04-23T18:32:04.152Z" }, + { url = "https://files.pythonhosted.org/packages/59/a7/63ef2fed1837d1121a894d0ce88439fe3e3b3e48c7543b2a4479eb99c2bd/pydantic_core-2.33.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:04a1a413977ab517154eebb2d326da71638271477d6ad87a769102f7c2488c56", size = 2107924, upload-time = "2025-04-23T18:32:06.129Z" }, + { url = "https://files.pythonhosted.org/packages/04/8f/2551964ef045669801675f1cfc3b0d74147f4901c3ffa42be2ddb1f0efc4/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:c8e7af2f4e0194c22b5b37205bfb293d166a7344a5b0d0eaccebc376546d77d5", size = 2063196, upload-time = "2025-04-23T18:32:08.178Z" }, + { url = "https://files.pythonhosted.org/packages/26/bd/d9602777e77fc6dbb0c7db9ad356e9a985825547dce5ad1d30ee04903918/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:5c92edd15cd58b3c2d34873597a1e20f13094f59cf88068adb18947df5455b4e", size = 2236389, upload-time = "2025-04-23T18:32:10.242Z" }, + { url = "https://files.pythonhosted.org/packages/42/db/0e950daa7e2230423ab342ae918a794964b053bec24ba8af013fc7c94846/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:65132b7b4a1c0beded5e057324b7e16e10910c106d43675d9bd87d4f38dde162", size = 2239223, upload-time = "2025-04-23T18:32:12.382Z" }, + { url = "https://files.pythonhosted.org/packages/58/4d/4f937099c545a8a17eb52cb67fe0447fd9a373b348ccfa9a87f141eeb00f/pydantic_core-2.33.2-cp313-cp313-win32.whl", hash = "sha256:52fb90784e0a242bb96ec53f42196a17278855b0f31ac7c3cc6f5c1ec4811849", size = 1900473, upload-time = "2025-04-23T18:32:14.034Z" }, + { url = "https://files.pythonhosted.org/packages/a0/75/4a0a9bac998d78d889def5e4ef2b065acba8cae8c93696906c3a91f310ca/pydantic_core-2.33.2-cp313-cp313-win_amd64.whl", hash = "sha256:c083a3bdd5a93dfe480f1125926afcdbf2917ae714bdb80b36d34318b2bec5d9", size = 1955269, upload-time = "2025-04-23T18:32:15.783Z" }, + { url = "https://files.pythonhosted.org/packages/f9/86/1beda0576969592f1497b4ce8e7bc8cbdf614c352426271b1b10d5f0aa64/pydantic_core-2.33.2-cp313-cp313-win_arm64.whl", hash = "sha256:e80b087132752f6b3d714f041ccf74403799d3b23a72722ea2e6ba2e892555b9", size = 1893921, upload-time = "2025-04-23T18:32:18.473Z" }, + { url = "https://files.pythonhosted.org/packages/a4/7d/e09391c2eebeab681df2b74bfe6c43422fffede8dc74187b2b0bf6fd7571/pydantic_core-2.33.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:61c18fba8e5e9db3ab908620af374db0ac1baa69f0f32df4f61ae23f15e586ac", size = 1806162, upload-time = "2025-04-23T18:32:20.188Z" }, + { url = "https://files.pythonhosted.org/packages/f1/3d/847b6b1fed9f8ed3bb95a9ad04fbd0b212e832d4f0f50ff4d9ee5a9f15cf/pydantic_core-2.33.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95237e53bb015f67b63c91af7518a62a8660376a6a0db19b89acc77a4d6199f5", size = 1981560, upload-time = "2025-04-23T18:32:22.354Z" }, + { url = "https://files.pythonhosted.org/packages/6f/9a/e73262f6c6656262b5fdd723ad90f518f579b7bc8622e43a942eec53c938/pydantic_core-2.33.2-cp313-cp313t-win_amd64.whl", hash = "sha256:c2fc0a768ef76c15ab9238afa6da7f69895bb5d1ee83aeea2e3509af4472d0b9", size = 1935777, upload-time = "2025-04-23T18:32:25.088Z" }, +] + +[[package]] +name = "python-dateutil" +version = "2.9.0.post0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/66/c0/0c8b6ad9f17a802ee498c46e004a0eb49bc148f2fd230864601a86dcf6db/python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 342432, upload-time = "2024-03-01T18:36:20.211Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, +] + +[[package]] +name = "pytz" +version = "2025.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f8/bf/abbd3cdfb8fbc7fb3d4d38d320f2441b1e7cbe29be4f23797b4a2b5d8aac/pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3", size = 320884, upload-time = "2025-03-25T02:25:00.538Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00", size = 509225, upload-time = "2025-03-25T02:24:58.468Z" }, +] + +[[package]] +name = "pyyaml" +version = "6.0.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/54/ed/79a089b6be93607fa5cdaedf301d7dfb23af5f25c398d5ead2525b063e17/pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e", size = 130631, upload-time = "2024-08-06T20:33:50.674Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/e3/3af305b830494fa85d95f6d95ef7fa73f2ee1cc8ef5b495c7c3269fb835f/PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba", size = 181309, upload-time = "2024-08-06T20:32:43.4Z" }, + { url = "https://files.pythonhosted.org/packages/45/9f/3b1c20a0b7a3200524eb0076cc027a970d320bd3a6592873c85c92a08731/PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1", size = 171679, upload-time = "2024-08-06T20:32:44.801Z" }, + { url = "https://files.pythonhosted.org/packages/7c/9a/337322f27005c33bcb656c655fa78325b730324c78620e8328ae28b64d0c/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133", size = 733428, upload-time = "2024-08-06T20:32:46.432Z" }, + { url = "https://files.pythonhosted.org/packages/a3/69/864fbe19e6c18ea3cc196cbe5d392175b4cf3d5d0ac1403ec3f2d237ebb5/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484", size = 763361, upload-time = "2024-08-06T20:32:51.188Z" }, + { url = "https://files.pythonhosted.org/packages/04/24/b7721e4845c2f162d26f50521b825fb061bc0a5afcf9a386840f23ea19fa/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5", size = 759523, upload-time = "2024-08-06T20:32:53.019Z" }, + { url = "https://files.pythonhosted.org/packages/2b/b2/e3234f59ba06559c6ff63c4e10baea10e5e7df868092bf9ab40e5b9c56b6/PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc", size = 726660, upload-time = "2024-08-06T20:32:54.708Z" }, + { url = "https://files.pythonhosted.org/packages/fe/0f/25911a9f080464c59fab9027482f822b86bf0608957a5fcc6eaac85aa515/PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652", size = 751597, upload-time = "2024-08-06T20:32:56.985Z" }, + { url = "https://files.pythonhosted.org/packages/14/0d/e2c3b43bbce3cf6bd97c840b46088a3031085179e596d4929729d8d68270/PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183", size = 140527, upload-time = "2024-08-06T20:33:03.001Z" }, + { url = "https://files.pythonhosted.org/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446, upload-time = "2024-08-06T20:33:04.33Z" }, +] + +[[package]] +name = "regex" +version = "2024.11.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/5f/bd69653fbfb76cf8604468d3b4ec4c403197144c7bfe0e6a5fc9e02a07cb/regex-2024.11.6.tar.gz", hash = "sha256:7ab159b063c52a0333c884e4679f8d7a85112ee3078fe3d9004b2dd875585519", size = 399494, upload-time = "2024-11-06T20:12:31.635Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/90/73/bcb0e36614601016552fa9344544a3a2ae1809dc1401b100eab02e772e1f/regex-2024.11.6-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a6ba92c0bcdf96cbf43a12c717eae4bc98325ca3730f6b130ffa2e3c3c723d84", size = 483525, upload-time = "2024-11-06T20:10:45.19Z" }, + { url = "https://files.pythonhosted.org/packages/0f/3f/f1a082a46b31e25291d830b369b6b0c5576a6f7fb89d3053a354c24b8a83/regex-2024.11.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:525eab0b789891ac3be914d36893bdf972d483fe66551f79d3e27146191a37d4", size = 288324, upload-time = "2024-11-06T20:10:47.177Z" }, + { url = "https://files.pythonhosted.org/packages/09/c9/4e68181a4a652fb3ef5099e077faf4fd2a694ea6e0f806a7737aff9e758a/regex-2024.11.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:086a27a0b4ca227941700e0b31425e7a28ef1ae8e5e05a33826e17e47fbfdba0", size = 284617, upload-time = "2024-11-06T20:10:49.312Z" }, + { url = "https://files.pythonhosted.org/packages/fc/fd/37868b75eaf63843165f1d2122ca6cb94bfc0271e4428cf58c0616786dce/regex-2024.11.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bde01f35767c4a7899b7eb6e823b125a64de314a8ee9791367c9a34d56af18d0", size = 795023, upload-time = "2024-11-06T20:10:51.102Z" }, + { url = "https://files.pythonhosted.org/packages/c4/7c/d4cd9c528502a3dedb5c13c146e7a7a539a3853dc20209c8e75d9ba9d1b2/regex-2024.11.6-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b583904576650166b3d920d2bcce13971f6f9e9a396c673187f49811b2769dc7", size = 833072, upload-time = "2024-11-06T20:10:52.926Z" }, + { url = "https://files.pythonhosted.org/packages/4f/db/46f563a08f969159c5a0f0e722260568425363bea43bb7ae370becb66a67/regex-2024.11.6-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1c4de13f06a0d54fa0d5ab1b7138bfa0d883220965a29616e3ea61b35d5f5fc7", size = 823130, upload-time = "2024-11-06T20:10:54.828Z" }, + { url = "https://files.pythonhosted.org/packages/db/60/1eeca2074f5b87df394fccaa432ae3fc06c9c9bfa97c5051aed70e6e00c2/regex-2024.11.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3cde6e9f2580eb1665965ce9bf17ff4952f34f5b126beb509fee8f4e994f143c", size = 796857, upload-time = "2024-11-06T20:10:56.634Z" }, + { url = "https://files.pythonhosted.org/packages/10/db/ac718a08fcee981554d2f7bb8402f1faa7e868c1345c16ab1ebec54b0d7b/regex-2024.11.6-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0d7f453dca13f40a02b79636a339c5b62b670141e63efd511d3f8f73fba162b3", size = 784006, upload-time = "2024-11-06T20:10:59.369Z" }, + { url = "https://files.pythonhosted.org/packages/c2/41/7da3fe70216cea93144bf12da2b87367590bcf07db97604edeea55dac9ad/regex-2024.11.6-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:59dfe1ed21aea057a65c6b586afd2a945de04fc7db3de0a6e3ed5397ad491b07", size = 781650, upload-time = "2024-11-06T20:11:02.042Z" }, + { url = "https://files.pythonhosted.org/packages/a7/d5/880921ee4eec393a4752e6ab9f0fe28009435417c3102fc413f3fe81c4e5/regex-2024.11.6-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b97c1e0bd37c5cd7902e65f410779d39eeda155800b65fc4d04cc432efa9bc6e", size = 789545, upload-time = "2024-11-06T20:11:03.933Z" }, + { url = "https://files.pythonhosted.org/packages/dc/96/53770115e507081122beca8899ab7f5ae28ae790bfcc82b5e38976df6a77/regex-2024.11.6-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f9d1e379028e0fc2ae3654bac3cbbef81bf3fd571272a42d56c24007979bafb6", size = 853045, upload-time = "2024-11-06T20:11:06.497Z" }, + { url = "https://files.pythonhosted.org/packages/31/d3/1372add5251cc2d44b451bd94f43b2ec78e15a6e82bff6a290ef9fd8f00a/regex-2024.11.6-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:13291b39131e2d002a7940fb176e120bec5145f3aeb7621be6534e46251912c4", size = 860182, upload-time = "2024-11-06T20:11:09.06Z" }, + { url = "https://files.pythonhosted.org/packages/ed/e3/c446a64984ea9f69982ba1a69d4658d5014bc7a0ea468a07e1a1265db6e2/regex-2024.11.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4f51f88c126370dcec4908576c5a627220da6c09d0bff31cfa89f2523843316d", size = 787733, upload-time = "2024-11-06T20:11:11.256Z" }, + { url = "https://files.pythonhosted.org/packages/2b/f1/e40c8373e3480e4f29f2692bd21b3e05f296d3afebc7e5dcf21b9756ca1c/regex-2024.11.6-cp313-cp313-win32.whl", hash = "sha256:63b13cfd72e9601125027202cad74995ab26921d8cd935c25f09c630436348ff", size = 262122, upload-time = "2024-11-06T20:11:13.161Z" }, + { url = "https://files.pythonhosted.org/packages/45/94/bc295babb3062a731f52621cdc992d123111282e291abaf23faa413443ea/regex-2024.11.6-cp313-cp313-win_amd64.whl", hash = "sha256:2b3361af3198667e99927da8b84c1b010752fa4b1115ee30beaa332cabc3ef1a", size = 273545, upload-time = "2024-11-06T20:11:15Z" }, +] + +[[package]] +name = "requests" +version = "2.32.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "charset-normalizer" }, + { name = "idna" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e1/0a/929373653770d8a0d7ea76c37de6e41f11eb07559b103b1c02cafb3f7cf8/requests-2.32.4.tar.gz", hash = "sha256:27d0316682c8a29834d3264820024b62a36942083d52caf2f14c0591336d3422", size = 135258, upload-time = "2025-06-09T16:43:07.34Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl", hash = "sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c", size = 64847, upload-time = "2025-06-09T16:43:05.728Z" }, +] + +[[package]] +name = "safetensors" +version = "0.5.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/71/7e/2d5d6ee7b40c0682315367ec7475693d110f512922d582fef1bd4a63adc3/safetensors-0.5.3.tar.gz", hash = "sha256:b6b0d6ecacec39a4fdd99cc19f4576f5219ce858e6fd8dbe7609df0b8dc56965", size = 67210, upload-time = "2025-02-26T09:15:13.155Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/18/ae/88f6c49dbd0cc4da0e08610019a3c78a7d390879a919411a410a1876d03a/safetensors-0.5.3-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:bd20eb133db8ed15b40110b7c00c6df51655a2998132193de2f75f72d99c7073", size = 436917, upload-time = "2025-02-26T09:15:03.702Z" }, + { url = "https://files.pythonhosted.org/packages/b8/3b/11f1b4a2f5d2ab7da34ecc062b0bc301f2be024d110a6466726bec8c055c/safetensors-0.5.3-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:21d01c14ff6c415c485616b8b0bf961c46b3b343ca59110d38d744e577f9cce7", size = 418419, upload-time = "2025-02-26T09:15:01.765Z" }, + { url = "https://files.pythonhosted.org/packages/5d/9a/add3e6fef267658075c5a41573c26d42d80c935cdc992384dfae435feaef/safetensors-0.5.3-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:11bce6164887cd491ca75c2326a113ba934be596e22b28b1742ce27b1d076467", size = 459493, upload-time = "2025-02-26T09:14:51.812Z" }, + { url = "https://files.pythonhosted.org/packages/df/5c/bf2cae92222513cc23b3ff85c4a1bb2811a2c3583ac0f8e8d502751de934/safetensors-0.5.3-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4a243be3590bc3301c821da7a18d87224ef35cbd3e5f5727e4e0728b8172411e", size = 472400, upload-time = "2025-02-26T09:14:53.549Z" }, + { url = "https://files.pythonhosted.org/packages/58/11/7456afb740bd45782d0f4c8e8e1bb9e572f1bf82899fb6ace58af47b4282/safetensors-0.5.3-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8bd84b12b1670a6f8e50f01e28156422a2bc07fb16fc4e98bded13039d688a0d", size = 522891, upload-time = "2025-02-26T09:14:55.717Z" }, + { url = "https://files.pythonhosted.org/packages/57/3d/fe73a9d2ace487e7285f6e157afee2383bd1ddb911b7cb44a55cf812eae3/safetensors-0.5.3-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:391ac8cab7c829452175f871fcaf414aa1e292b5448bd02620f675a7f3e7abb9", size = 537694, upload-time = "2025-02-26T09:14:57.036Z" }, + { url = "https://files.pythonhosted.org/packages/a6/f8/dae3421624fcc87a89d42e1898a798bc7ff72c61f38973a65d60df8f124c/safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cead1fa41fc54b1e61089fa57452e8834f798cb1dc7a09ba3524f1eb08e0317a", size = 471642, upload-time = "2025-02-26T09:15:00.544Z" }, + { url = "https://files.pythonhosted.org/packages/ce/20/1fbe16f9b815f6c5a672f5b760951e20e17e43f67f231428f871909a37f6/safetensors-0.5.3-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1077f3e94182d72618357b04b5ced540ceb71c8a813d3319f1aba448e68a770d", size = 502241, upload-time = "2025-02-26T09:14:58.303Z" }, + { url = "https://files.pythonhosted.org/packages/5f/18/8e108846b506487aa4629fe4116b27db65c3dde922de2c8e0cc1133f3f29/safetensors-0.5.3-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:799021e78287bac619c7b3f3606730a22da4cda27759ddf55d37c8db7511c74b", size = 638001, upload-time = "2025-02-26T09:15:05.79Z" }, + { url = "https://files.pythonhosted.org/packages/82/5a/c116111d8291af6c8c8a8b40628fe833b9db97d8141c2a82359d14d9e078/safetensors-0.5.3-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:df26da01aaac504334644e1b7642fa000bfec820e7cef83aeac4e355e03195ff", size = 734013, upload-time = "2025-02-26T09:15:07.892Z" }, + { url = "https://files.pythonhosted.org/packages/7d/ff/41fcc4d3b7de837963622e8610d998710705bbde9a8a17221d85e5d0baad/safetensors-0.5.3-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:32c3ef2d7af8b9f52ff685ed0bc43913cdcde135089ae322ee576de93eae5135", size = 670687, upload-time = "2025-02-26T09:15:09.979Z" }, + { url = "https://files.pythonhosted.org/packages/40/ad/2b113098e69c985a3d8fbda4b902778eae4a35b7d5188859b4a63d30c161/safetensors-0.5.3-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:37f1521be045e56fc2b54c606d4455573e717b2d887c579ee1dbba5f868ece04", size = 643147, upload-time = "2025-02-26T09:15:11.185Z" }, + { url = "https://files.pythonhosted.org/packages/0a/0c/95aeb51d4246bd9a3242d3d8349c1112b4ee7611a4b40f0c5c93b05f001d/safetensors-0.5.3-cp38-abi3-win32.whl", hash = "sha256:cfc0ec0846dcf6763b0ed3d1846ff36008c6e7290683b61616c4b040f6a54ace", size = 296677, upload-time = "2025-02-26T09:15:16.554Z" }, + { url = "https://files.pythonhosted.org/packages/69/e2/b011c38e5394c4c18fb5500778a55ec43ad6106126e74723ffaee246f56e/safetensors-0.5.3-cp38-abi3-win_amd64.whl", hash = "sha256:836cbbc320b47e80acd40e44c8682db0e8ad7123209f69b093def21ec7cafd11", size = 308878, upload-time = "2025-02-26T09:15:14.99Z" }, +] + +[[package]] +name = "setuptools" +version = "80.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/18/5d/3bf57dcd21979b887f014ea83c24ae194cfcd12b9e0fda66b957c69d1fca/setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c", size = 1319958, upload-time = "2025-05-27T00:56:51.443Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a3/dc/17031897dae0efacfea57dfd3a82fdd2a2aeb58e0ff71b77b87e44edc772/setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922", size = 1201486, upload-time = "2025-05-27T00:56:49.664Z" }, +] + +[[package]] +name = "six" +version = "1.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/94/e7/b2c673351809dca68a0e064b6af791aa332cf192da575fd474ed7d6f16a2/six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81", size = 34031, upload-time = "2024-12-04T17:35:28.174Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, +] + +[[package]] +name = "sympy" +version = "1.14.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mpmath" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/83/d3/803453b36afefb7c2bb238361cd4ae6125a569b4db67cd9e79846ba2d68c/sympy-1.14.0.tar.gz", hash = "sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517", size = 7793921, upload-time = "2025-04-27T18:05:01.611Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" }, +] + +[[package]] +name = "tokenizers" +version = "0.21.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "huggingface-hub" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/92/76/5ac0c97f1117b91b7eb7323dcd61af80d72f790b4df71249a7850c195f30/tokenizers-0.21.1.tar.gz", hash = "sha256:a1bb04dc5b448985f86ecd4b05407f5a8d97cb2c0532199b2a302a604a0165ab", size = 343256, upload-time = "2025-03-13T10:51:18.189Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a5/1f/328aee25f9115bf04262e8b4e5a2050b7b7cf44b59c74e982db7270c7f30/tokenizers-0.21.1-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:e78e413e9e668ad790a29456e677d9d3aa50a9ad311a40905d6861ba7692cf41", size = 2780767, upload-time = "2025-03-13T10:51:09.459Z" }, + { url = "https://files.pythonhosted.org/packages/ae/1a/4526797f3719b0287853f12c5ad563a9be09d446c44ac784cdd7c50f76ab/tokenizers-0.21.1-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:cd51cd0a91ecc801633829fcd1fda9cf8682ed3477c6243b9a095539de4aecf3", size = 2650555, upload-time = "2025-03-13T10:51:07.692Z" }, + { url = "https://files.pythonhosted.org/packages/4d/7a/a209b29f971a9fdc1da86f917fe4524564924db50d13f0724feed37b2a4d/tokenizers-0.21.1-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28da6b72d4fb14ee200a1bd386ff74ade8992d7f725f2bde2c495a9a98cf4d9f", size = 2937541, upload-time = "2025-03-13T10:50:56.679Z" }, + { url = "https://files.pythonhosted.org/packages/3c/1e/b788b50ffc6191e0b1fc2b0d49df8cff16fe415302e5ceb89f619d12c5bc/tokenizers-0.21.1-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:34d8cfde551c9916cb92014e040806122295a6800914bab5865deb85623931cf", size = 2819058, upload-time = "2025-03-13T10:50:59.525Z" }, + { url = "https://files.pythonhosted.org/packages/36/aa/3626dfa09a0ecc5b57a8c58eeaeb7dd7ca9a37ad9dd681edab5acd55764c/tokenizers-0.21.1-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aaa852d23e125b73d283c98f007e06d4595732104b65402f46e8ef24b588d9f8", size = 3133278, upload-time = "2025-03-13T10:51:04.678Z" }, + { url = "https://files.pythonhosted.org/packages/a4/4d/8fbc203838b3d26269f944a89459d94c858f5b3f9a9b6ee9728cdcf69161/tokenizers-0.21.1-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a21a15d5c8e603331b8a59548bbe113564136dc0f5ad8306dd5033459a226da0", size = 3144253, upload-time = "2025-03-13T10:51:01.261Z" }, + { url = "https://files.pythonhosted.org/packages/d8/1b/2bd062adeb7c7511b847b32e356024980c0ffcf35f28947792c2d8ad2288/tokenizers-0.21.1-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2fdbd4c067c60a0ac7eca14b6bd18a5bebace54eb757c706b47ea93204f7a37c", size = 3398225, upload-time = "2025-03-13T10:51:03.243Z" }, + { url = "https://files.pythonhosted.org/packages/8a/63/38be071b0c8e06840bc6046991636bcb30c27f6bb1e670f4f4bc87cf49cc/tokenizers-0.21.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2dd9a0061e403546f7377df940e866c3e678d7d4e9643d0461ea442b4f89e61a", size = 3038874, upload-time = "2025-03-13T10:51:06.235Z" }, + { url = "https://files.pythonhosted.org/packages/ec/83/afa94193c09246417c23a3c75a8a0a96bf44ab5630a3015538d0c316dd4b/tokenizers-0.21.1-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:db9484aeb2e200c43b915a1a0150ea885e35f357a5a8fabf7373af333dcc8dbf", size = 9014448, upload-time = "2025-03-13T10:51:10.927Z" }, + { url = "https://files.pythonhosted.org/packages/ae/b3/0e1a37d4f84c0f014d43701c11eb8072704f6efe8d8fc2dcdb79c47d76de/tokenizers-0.21.1-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:ed248ab5279e601a30a4d67bdb897ecbe955a50f1e7bb62bd99f07dd11c2f5b6", size = 8937877, upload-time = "2025-03-13T10:51:12.688Z" }, + { url = "https://files.pythonhosted.org/packages/ac/33/ff08f50e6d615eb180a4a328c65907feb6ded0b8f990ec923969759dc379/tokenizers-0.21.1-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:9ac78b12e541d4ce67b4dfd970e44c060a2147b9b2a21f509566d556a509c67d", size = 9186645, upload-time = "2025-03-13T10:51:14.723Z" }, + { url = "https://files.pythonhosted.org/packages/5f/aa/8ae85f69a9f6012c6f8011c6f4aa1c96154c816e9eea2e1b758601157833/tokenizers-0.21.1-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:e5a69c1a4496b81a5ee5d2c1f3f7fbdf95e90a0196101b0ee89ed9956b8a168f", size = 9384380, upload-time = "2025-03-13T10:51:16.526Z" }, + { url = "https://files.pythonhosted.org/packages/e8/5b/a5d98c89f747455e8b7a9504910c865d5e51da55e825a7ae641fb5ff0a58/tokenizers-0.21.1-cp39-abi3-win32.whl", hash = "sha256:1039a3a5734944e09de1d48761ade94e00d0fa760c0e0551151d4dd851ba63e3", size = 2239506, upload-time = "2025-03-13T10:51:20.643Z" }, + { url = "https://files.pythonhosted.org/packages/e6/b6/072a8e053ae600dcc2ac0da81a23548e3b523301a442a6ca900e92ac35be/tokenizers-0.21.1-cp39-abi3-win_amd64.whl", hash = "sha256:0f0dcbcc9f6e13e675a66d7a5f2f225a736745ce484c1a4e07476a89ccdad382", size = 2435481, upload-time = "2025-03-13T10:51:19.243Z" }, +] + +[[package]] +name = "torch" +version = "2.7.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "filelock" }, + { name = "fsspec" }, + { name = "jinja2" }, + { name = "networkx" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufile-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "setuptools" }, + { name = "sympy" }, + { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "typing-extensions" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/66/81/e48c9edb655ee8eb8c2a6026abdb6f8d2146abd1f150979ede807bb75dcb/torch-2.7.1-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:03563603d931e70722dce0e11999d53aa80a375a3d78e6b39b9f6805ea0a8d28", size = 98946649, upload-time = "2025-06-04T17:38:43.031Z" }, + { url = "https://files.pythonhosted.org/packages/3a/24/efe2f520d75274fc06b695c616415a1e8a1021d87a13c68ff9dce733d088/torch-2.7.1-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:d632f5417b6980f61404a125b999ca6ebd0b8b4bbdbb5fbbba44374ab619a412", size = 821033192, upload-time = "2025-06-04T17:38:09.146Z" }, + { url = "https://files.pythonhosted.org/packages/dd/d9/9c24d230333ff4e9b6807274f6f8d52a864210b52ec794c5def7925f4495/torch-2.7.1-cp313-cp313-win_amd64.whl", hash = "sha256:23660443e13995ee93e3d844786701ea4ca69f337027b05182f5ba053ce43b38", size = 216055668, upload-time = "2025-06-04T17:38:36.253Z" }, + { url = "https://files.pythonhosted.org/packages/95/bf/e086ee36ddcef9299f6e708d3b6c8487c1651787bb9ee2939eb2a7f74911/torch-2.7.1-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:0da4f4dba9f65d0d203794e619fe7ca3247a55ffdcbd17ae8fb83c8b2dc9b585", size = 68925988, upload-time = "2025-06-04T17:38:29.273Z" }, + { url = "https://files.pythonhosted.org/packages/69/6a/67090dcfe1cf9048448b31555af6efb149f7afa0a310a366adbdada32105/torch-2.7.1-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:e08d7e6f21a617fe38eeb46dd2213ded43f27c072e9165dc27300c9ef9570934", size = 99028857, upload-time = "2025-06-04T17:37:50.956Z" }, + { url = "https://files.pythonhosted.org/packages/90/1c/48b988870823d1cc381f15ec4e70ed3d65e043f43f919329b0045ae83529/torch-2.7.1-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:30207f672328a42df4f2174b8f426f354b2baa0b7cca3a0adb3d6ab5daf00dc8", size = 821098066, upload-time = "2025-06-04T17:37:33.939Z" }, + { url = "https://files.pythonhosted.org/packages/7b/eb/10050d61c9d5140c5dc04a89ed3257ef1a6b93e49dd91b95363d757071e0/torch-2.7.1-cp313-cp313t-win_amd64.whl", hash = "sha256:79042feca1c634aaf6603fe6feea8c6b30dfa140a6bbc0b973e2260c7e79a22e", size = 216336310, upload-time = "2025-06-04T17:36:09.862Z" }, + { url = "https://files.pythonhosted.org/packages/b1/29/beb45cdf5c4fc3ebe282bf5eafc8dfd925ead7299b3c97491900fe5ed844/torch-2.7.1-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:988b0cbc4333618a1056d2ebad9eb10089637b659eb645434d0809d8d937b946", size = 68645708, upload-time = "2025-06-04T17:34:39.852Z" }, +] + +[[package]] +name = "tqdm" +version = "4.67.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737, upload-time = "2024-11-24T20:12:22.481Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" }, +] + +[[package]] +name = "transformers" +version = "4.52.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "filelock" }, + { name = "huggingface-hub" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "pyyaml" }, + { name = "regex" }, + { name = "requests" }, + { name = "safetensors" }, + { name = "tokenizers" }, + { name = "tqdm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/da/a9/275037087f9d846580b02f2d7cae0e0a6955d46f84583d0151d6227bd416/transformers-4.52.4.tar.gz", hash = "sha256:aff3764441c1adc192a08dba49740d3cbbcb72d850586075aed6bd89b98203e6", size = 8945376, upload-time = "2025-05-30T09:17:17.947Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/96/f2/25b27b396af03d5b64e61976b14f7209e2939e9e806c10749b6d277c273e/transformers-4.52.4-py3-none-any.whl", hash = "sha256:203f5c19416d5877e36e88633943761719538a25d9775977a24fe77a1e5adfc7", size = 10460375, upload-time = "2025-05-30T09:17:14.477Z" }, +] + +[[package]] +name = "triton" +version = "3.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "setuptools" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/74/1f/dfb531f90a2d367d914adfee771babbd3f1a5b26c3f5fbc458dee21daa78/triton-3.3.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b89d846b5a4198317fec27a5d3a609ea96b6d557ff44b56c23176546023c4240", size = 155673035, upload-time = "2025-05-29T23:40:02.468Z" }, + { url = "https://files.pythonhosted.org/packages/28/71/bd20ffcb7a64c753dc2463489a61bf69d531f308e390ad06390268c4ea04/triton-3.3.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a3198adb9d78b77818a5388bff89fa72ff36f9da0bc689db2f0a651a67ce6a42", size = 155735832, upload-time = "2025-05-29T23:40:10.522Z" }, +] + +[[package]] +name = "trl" +version = "0.0.0" +source = { virtual = "." } +dependencies = [ + { name = "accelerate" }, + { name = "datasets" }, + { name = "deepspeed" }, + { name = "peft" }, + { name = "transformers" }, +] + +[package.metadata] +requires-dist = [ + { name = "accelerate", specifier = ">=1.7.0" }, + { name = "datasets", specifier = ">=3.6.0" }, + { name = "deepspeed", specifier = ">=0.17.1" }, + { name = "peft", specifier = ">=0.15.2" }, + { name = "transformers", specifier = ">=4.52.4" }, +] + +[[package]] +name = "typing-extensions" +version = "4.14.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d1/bc/51647cd02527e87d05cb083ccc402f93e441606ff1f01739a62c8ad09ba5/typing_extensions-4.14.0.tar.gz", hash = "sha256:8676b788e32f02ab42d9e7c61324048ae4c6d844a399eebace3d4979d75ceef4", size = 107423, upload-time = "2025-06-02T14:52:11.399Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/69/e0/552843e0d356fbb5256d21449fa957fa4eff3bbc135a74a691ee70c7c5da/typing_extensions-4.14.0-py3-none-any.whl", hash = "sha256:a1514509136dd0b477638fc68d6a91497af5076466ad0fa6c338e44e359944af", size = 43839, upload-time = "2025-06-02T14:52:10.026Z" }, +] + +[[package]] +name = "typing-inspection" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f8/b1/0c11f5058406b3af7609f121aaa6b609744687f1d158b3c3a5bf4cc94238/typing_inspection-0.4.1.tar.gz", hash = "sha256:6ae134cc0203c33377d43188d4064e9b357dba58cff3185f22924610e70a9d28", size = 75726, upload-time = "2025-05-21T18:55:23.885Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/17/69/cd203477f944c353c31bade965f880aa1061fd6bf05ded0726ca845b6ff7/typing_inspection-0.4.1-py3-none-any.whl", hash = "sha256:389055682238f53b04f7badcb49b989835495a96700ced5dab2d8feae4b26f51", size = 14552, upload-time = "2025-05-21T18:55:22.152Z" }, +] + +[[package]] +name = "tzdata" +version = "2025.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/95/32/1a225d6164441be760d75c2c42e2780dc0873fe382da3e98a2e1e48361e5/tzdata-2025.2.tar.gz", hash = "sha256:b60a638fcc0daffadf82fe0f57e53d06bdec2f36c4df66280ae79bce6bd6f2b9", size = 196380, upload-time = "2025-03-23T13:54:43.652Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/23/c7abc0ca0a1526a0774eca151daeb8de62ec457e77262b66b359c3c7679e/tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8", size = 347839, upload-time = "2025-03-23T13:54:41.845Z" }, +] + +[[package]] +name = "urllib3" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8a/78/16493d9c386d8e60e442a35feac5e00f0913c0f4b7c217c11e8ec2ff53e0/urllib3-2.4.0.tar.gz", hash = "sha256:414bc6535b787febd7567804cc015fee39daab8ad86268f1310a9250697de466", size = 390672, upload-time = "2025-04-10T15:23:39.232Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6b/11/cc635220681e93a0183390e26485430ca2c7b5f9d33b15c74c2861cb8091/urllib3-2.4.0-py3-none-any.whl", hash = "sha256:4e16665048960a0900c702d4a66415956a584919c03361cac9f1df5c5dd7e813", size = 128680, upload-time = "2025-04-10T15:23:37.377Z" }, +] + +[[package]] +name = "xxhash" +version = "3.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/00/5e/d6e5258d69df8b4ed8c83b6664f2b47d30d2dec551a29ad72a6c69eafd31/xxhash-3.5.0.tar.gz", hash = "sha256:84f2caddf951c9cbf8dc2e22a89d4ccf5d86391ac6418fe81e3c67d0cf60b45f", size = 84241, upload-time = "2024-08-17T09:20:38.972Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c9/b8/e4b3ad92d249be5c83fa72916c9091b0965cb0faeff05d9a0a3870ae6bff/xxhash-3.5.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:37889a0d13b0b7d739cfc128b1c902f04e32de17b33d74b637ad42f1c55101f6", size = 31795, upload-time = "2024-08-17T09:18:46.813Z" }, + { url = "https://files.pythonhosted.org/packages/fc/d8/b3627a0aebfbfa4c12a41e22af3742cf08c8ea84f5cc3367b5de2d039cce/xxhash-3.5.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:97a662338797c660178e682f3bc180277b9569a59abfb5925e8620fba00b9fc5", size = 30792, upload-time = "2024-08-17T09:18:47.862Z" }, + { url = "https://files.pythonhosted.org/packages/c3/cc/762312960691da989c7cd0545cb120ba2a4148741c6ba458aa723c00a3f8/xxhash-3.5.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7f85e0108d51092bdda90672476c7d909c04ada6923c14ff9d913c4f7dc8a3bc", size = 220950, upload-time = "2024-08-17T09:18:49.06Z" }, + { url = "https://files.pythonhosted.org/packages/fe/e9/cc266f1042c3c13750e86a535496b58beb12bf8c50a915c336136f6168dc/xxhash-3.5.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cd2fd827b0ba763ac919440042302315c564fdb797294d86e8cdd4578e3bc7f3", size = 199980, upload-time = "2024-08-17T09:18:50.445Z" }, + { url = "https://files.pythonhosted.org/packages/bf/85/a836cd0dc5cc20376de26b346858d0ac9656f8f730998ca4324921a010b9/xxhash-3.5.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:82085c2abec437abebf457c1d12fccb30cc8b3774a0814872511f0f0562c768c", size = 428324, upload-time = "2024-08-17T09:18:51.988Z" }, + { url = "https://files.pythonhosted.org/packages/b4/0e/15c243775342ce840b9ba34aceace06a1148fa1630cd8ca269e3223987f5/xxhash-3.5.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07fda5de378626e502b42b311b049848c2ef38784d0d67b6f30bb5008642f8eb", size = 194370, upload-time = "2024-08-17T09:18:54.164Z" }, + { url = "https://files.pythonhosted.org/packages/87/a1/b028bb02636dfdc190da01951d0703b3d904301ed0ef6094d948983bef0e/xxhash-3.5.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c279f0d2b34ef15f922b77966640ade58b4ccdfef1c4d94b20f2a364617a493f", size = 207911, upload-time = "2024-08-17T09:18:55.509Z" }, + { url = "https://files.pythonhosted.org/packages/80/d5/73c73b03fc0ac73dacf069fdf6036c9abad82de0a47549e9912c955ab449/xxhash-3.5.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:89e66ceed67b213dec5a773e2f7a9e8c58f64daeb38c7859d8815d2c89f39ad7", size = 216352, upload-time = "2024-08-17T09:18:57.073Z" }, + { url = "https://files.pythonhosted.org/packages/b6/2a/5043dba5ddbe35b4fe6ea0a111280ad9c3d4ba477dd0f2d1fe1129bda9d0/xxhash-3.5.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bcd51708a633410737111e998ceb3b45d3dbc98c0931f743d9bb0a209033a326", size = 203410, upload-time = "2024-08-17T09:18:58.54Z" }, + { url = "https://files.pythonhosted.org/packages/a2/b2/9a8ded888b7b190aed75b484eb5c853ddd48aa2896e7b59bbfbce442f0a1/xxhash-3.5.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:3ff2c0a34eae7df88c868be53a8dd56fbdf592109e21d4bfa092a27b0bf4a7bf", size = 210322, upload-time = "2024-08-17T09:18:59.943Z" }, + { url = "https://files.pythonhosted.org/packages/98/62/440083fafbc917bf3e4b67c2ade621920dd905517e85631c10aac955c1d2/xxhash-3.5.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:4e28503dccc7d32e0b9817aa0cbfc1f45f563b2c995b7a66c4c8a0d232e840c7", size = 414725, upload-time = "2024-08-17T09:19:01.332Z" }, + { url = "https://files.pythonhosted.org/packages/75/db/009206f7076ad60a517e016bb0058381d96a007ce3f79fa91d3010f49cc2/xxhash-3.5.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a6c50017518329ed65a9e4829154626f008916d36295b6a3ba336e2458824c8c", size = 192070, upload-time = "2024-08-17T09:19:03.007Z" }, + { url = "https://files.pythonhosted.org/packages/1f/6d/c61e0668943a034abc3a569cdc5aeae37d686d9da7e39cf2ed621d533e36/xxhash-3.5.0-cp313-cp313-win32.whl", hash = "sha256:53a068fe70301ec30d868ece566ac90d873e3bb059cf83c32e76012c889b8637", size = 30172, upload-time = "2024-08-17T09:19:04.355Z" }, + { url = "https://files.pythonhosted.org/packages/96/14/8416dce965f35e3d24722cdf79361ae154fa23e2ab730e5323aa98d7919e/xxhash-3.5.0-cp313-cp313-win_amd64.whl", hash = "sha256:80babcc30e7a1a484eab952d76a4f4673ff601f54d5142c26826502740e70b43", size = 30041, upload-time = "2024-08-17T09:19:05.435Z" }, + { url = "https://files.pythonhosted.org/packages/27/ee/518b72faa2073f5aa8e3262408d284892cb79cf2754ba0c3a5870645ef73/xxhash-3.5.0-cp313-cp313-win_arm64.whl", hash = "sha256:4811336f1ce11cac89dcbd18f3a25c527c16311709a89313c3acaf771def2d4b", size = 26801, upload-time = "2024-08-17T09:19:06.547Z" }, +] + +[[package]] +name = "yarl" +version = "1.20.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "idna" }, + { name = "multidict" }, + { name = "propcache" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3c/fb/efaa23fa4e45537b827620f04cf8f3cd658b76642205162e072703a5b963/yarl-1.20.1.tar.gz", hash = "sha256:d017a4997ee50c91fd5466cef416231bb82177b93b029906cefc542ce14c35ac", size = 186428, upload-time = "2025-06-10T00:46:09.923Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8a/e1/2411b6d7f769a07687acee88a062af5833cf1966b7266f3d8dfb3d3dc7d3/yarl-1.20.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:0b5ff0fbb7c9f1b1b5ab53330acbfc5247893069e7716840c8e7d5bb7355038a", size = 131811, upload-time = "2025-06-10T00:44:18.933Z" }, + { url = "https://files.pythonhosted.org/packages/b2/27/584394e1cb76fb771371770eccad35de400e7b434ce3142c2dd27392c968/yarl-1.20.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:14f326acd845c2b2e2eb38fb1346c94f7f3b01a4f5c788f8144f9b630bfff9a3", size = 90078, upload-time = "2025-06-10T00:44:20.635Z" }, + { url = "https://files.pythonhosted.org/packages/bf/9a/3246ae92d4049099f52d9b0fe3486e3b500e29b7ea872d0f152966fc209d/yarl-1.20.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f60e4ad5db23f0b96e49c018596707c3ae89f5d0bd97f0ad3684bcbad899f1e7", size = 88748, upload-time = "2025-06-10T00:44:22.34Z" }, + { url = "https://files.pythonhosted.org/packages/a3/25/35afe384e31115a1a801fbcf84012d7a066d89035befae7c5d4284df1e03/yarl-1.20.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:49bdd1b8e00ce57e68ba51916e4bb04461746e794e7c4d4bbc42ba2f18297691", size = 349595, upload-time = "2025-06-10T00:44:24.314Z" }, + { url = "https://files.pythonhosted.org/packages/28/2d/8aca6cb2cabc8f12efcb82749b9cefecbccfc7b0384e56cd71058ccee433/yarl-1.20.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:66252d780b45189975abfed839616e8fd2dbacbdc262105ad7742c6ae58f3e31", size = 342616, upload-time = "2025-06-10T00:44:26.167Z" }, + { url = "https://files.pythonhosted.org/packages/0b/e9/1312633d16b31acf0098d30440ca855e3492d66623dafb8e25b03d00c3da/yarl-1.20.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:59174e7332f5d153d8f7452a102b103e2e74035ad085f404df2e40e663a22b28", size = 361324, upload-time = "2025-06-10T00:44:27.915Z" }, + { url = "https://files.pythonhosted.org/packages/bc/a0/688cc99463f12f7669eec7c8acc71ef56a1521b99eab7cd3abb75af887b0/yarl-1.20.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e3968ec7d92a0c0f9ac34d5ecfd03869ec0cab0697c91a45db3fbbd95fe1b653", size = 359676, upload-time = "2025-06-10T00:44:30.041Z" }, + { url = "https://files.pythonhosted.org/packages/af/44/46407d7f7a56e9a85a4c207724c9f2c545c060380718eea9088f222ba697/yarl-1.20.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1a4fbb50e14396ba3d375f68bfe02215d8e7bc3ec49da8341fe3157f59d2ff5", size = 352614, upload-time = "2025-06-10T00:44:32.171Z" }, + { url = "https://files.pythonhosted.org/packages/b1/91/31163295e82b8d5485d31d9cf7754d973d41915cadce070491778d9c9825/yarl-1.20.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:11a62c839c3a8eac2410e951301309426f368388ff2f33799052787035793b02", size = 336766, upload-time = "2025-06-10T00:44:34.494Z" }, + { url = "https://files.pythonhosted.org/packages/b4/8e/c41a5bc482121f51c083c4c2bcd16b9e01e1cf8729e380273a952513a21f/yarl-1.20.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:041eaa14f73ff5a8986b4388ac6bb43a77f2ea09bf1913df7a35d4646db69e53", size = 364615, upload-time = "2025-06-10T00:44:36.856Z" }, + { url = "https://files.pythonhosted.org/packages/e3/5b/61a3b054238d33d70ea06ebba7e58597891b71c699e247df35cc984ab393/yarl-1.20.1-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:377fae2fef158e8fd9d60b4c8751387b8d1fb121d3d0b8e9b0be07d1b41e83dc", size = 360982, upload-time = "2025-06-10T00:44:39.141Z" }, + { url = "https://files.pythonhosted.org/packages/df/a3/6a72fb83f8d478cb201d14927bc8040af901811a88e0ff2da7842dd0ed19/yarl-1.20.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:1c92f4390e407513f619d49319023664643d3339bd5e5a56a3bebe01bc67ec04", size = 369792, upload-time = "2025-06-10T00:44:40.934Z" }, + { url = "https://files.pythonhosted.org/packages/7c/af/4cc3c36dfc7c077f8dedb561eb21f69e1e9f2456b91b593882b0b18c19dc/yarl-1.20.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:d25ddcf954df1754ab0f86bb696af765c5bfaba39b74095f27eececa049ef9a4", size = 382049, upload-time = "2025-06-10T00:44:42.854Z" }, + { url = "https://files.pythonhosted.org/packages/19/3a/e54e2c4752160115183a66dc9ee75a153f81f3ab2ba4bf79c3c53b33de34/yarl-1.20.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:909313577e9619dcff8c31a0ea2aa0a2a828341d92673015456b3ae492e7317b", size = 384774, upload-time = "2025-06-10T00:44:45.275Z" }, + { url = "https://files.pythonhosted.org/packages/9c/20/200ae86dabfca89060ec6447649f219b4cbd94531e425e50d57e5f5ac330/yarl-1.20.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:793fd0580cb9664548c6b83c63b43c477212c0260891ddf86809e1c06c8b08f1", size = 374252, upload-time = "2025-06-10T00:44:47.31Z" }, + { url = "https://files.pythonhosted.org/packages/83/75/11ee332f2f516b3d094e89448da73d557687f7d137d5a0f48c40ff211487/yarl-1.20.1-cp313-cp313-win32.whl", hash = "sha256:468f6e40285de5a5b3c44981ca3a319a4b208ccc07d526b20b12aeedcfa654b7", size = 81198, upload-time = "2025-06-10T00:44:49.164Z" }, + { url = "https://files.pythonhosted.org/packages/ba/ba/39b1ecbf51620b40ab402b0fc817f0ff750f6d92712b44689c2c215be89d/yarl-1.20.1-cp313-cp313-win_amd64.whl", hash = "sha256:495b4ef2fea40596bfc0affe3837411d6aa3371abcf31aac0ccc4bdd64d4ef5c", size = 86346, upload-time = "2025-06-10T00:44:51.182Z" }, + { url = "https://files.pythonhosted.org/packages/43/c7/669c52519dca4c95153c8ad96dd123c79f354a376346b198f438e56ffeb4/yarl-1.20.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:f60233b98423aab21d249a30eb27c389c14929f47be8430efa7dbd91493a729d", size = 138826, upload-time = "2025-06-10T00:44:52.883Z" }, + { url = "https://files.pythonhosted.org/packages/6a/42/fc0053719b44f6ad04a75d7f05e0e9674d45ef62f2d9ad2c1163e5c05827/yarl-1.20.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:6f3eff4cc3f03d650d8755c6eefc844edde99d641d0dcf4da3ab27141a5f8ddf", size = 93217, upload-time = "2025-06-10T00:44:54.658Z" }, + { url = "https://files.pythonhosted.org/packages/4f/7f/fa59c4c27e2a076bba0d959386e26eba77eb52ea4a0aac48e3515c186b4c/yarl-1.20.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:69ff8439d8ba832d6bed88af2c2b3445977eba9a4588b787b32945871c2444e3", size = 92700, upload-time = "2025-06-10T00:44:56.784Z" }, + { url = "https://files.pythonhosted.org/packages/2f/d4/062b2f48e7c93481e88eff97a6312dca15ea200e959f23e96d8ab898c5b8/yarl-1.20.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3cf34efa60eb81dd2645a2e13e00bb98b76c35ab5061a3989c7a70f78c85006d", size = 347644, upload-time = "2025-06-10T00:44:59.071Z" }, + { url = "https://files.pythonhosted.org/packages/89/47/78b7f40d13c8f62b499cc702fdf69e090455518ae544c00a3bf4afc9fc77/yarl-1.20.1-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:8e0fe9364ad0fddab2688ce72cb7a8e61ea42eff3c7caeeb83874a5d479c896c", size = 323452, upload-time = "2025-06-10T00:45:01.605Z" }, + { url = "https://files.pythonhosted.org/packages/eb/2b/490d3b2dc66f52987d4ee0d3090a147ea67732ce6b4d61e362c1846d0d32/yarl-1.20.1-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8f64fbf81878ba914562c672024089e3401974a39767747691c65080a67b18c1", size = 346378, upload-time = "2025-06-10T00:45:03.946Z" }, + { url = "https://files.pythonhosted.org/packages/66/ad/775da9c8a94ce925d1537f939a4f17d782efef1f973039d821cbe4bcc211/yarl-1.20.1-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f6342d643bf9a1de97e512e45e4b9560a043347e779a173250824f8b254bd5ce", size = 353261, upload-time = "2025-06-10T00:45:05.992Z" }, + { url = "https://files.pythonhosted.org/packages/4b/23/0ed0922b47a4f5c6eb9065d5ff1e459747226ddce5c6a4c111e728c9f701/yarl-1.20.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56dac5f452ed25eef0f6e3c6a066c6ab68971d96a9fb441791cad0efba6140d3", size = 335987, upload-time = "2025-06-10T00:45:08.227Z" }, + { url = "https://files.pythonhosted.org/packages/3e/49/bc728a7fe7d0e9336e2b78f0958a2d6b288ba89f25a1762407a222bf53c3/yarl-1.20.1-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c7d7f497126d65e2cad8dc5f97d34c27b19199b6414a40cb36b52f41b79014be", size = 329361, upload-time = "2025-06-10T00:45:10.11Z" }, + { url = "https://files.pythonhosted.org/packages/93/8f/b811b9d1f617c83c907e7082a76e2b92b655400e61730cd61a1f67178393/yarl-1.20.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:67e708dfb8e78d8a19169818eeb5c7a80717562de9051bf2413aca8e3696bf16", size = 346460, upload-time = "2025-06-10T00:45:12.055Z" }, + { url = "https://files.pythonhosted.org/packages/70/fd/af94f04f275f95da2c3b8b5e1d49e3e79f1ed8b6ceb0f1664cbd902773ff/yarl-1.20.1-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:595c07bc79af2494365cc96ddeb772f76272364ef7c80fb892ef9d0649586513", size = 334486, upload-time = "2025-06-10T00:45:13.995Z" }, + { url = "https://files.pythonhosted.org/packages/84/65/04c62e82704e7dd0a9b3f61dbaa8447f8507655fd16c51da0637b39b2910/yarl-1.20.1-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:7bdd2f80f4a7df852ab9ab49484a4dee8030023aa536df41f2d922fd57bf023f", size = 342219, upload-time = "2025-06-10T00:45:16.479Z" }, + { url = "https://files.pythonhosted.org/packages/91/95/459ca62eb958381b342d94ab9a4b6aec1ddec1f7057c487e926f03c06d30/yarl-1.20.1-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:c03bfebc4ae8d862f853a9757199677ab74ec25424d0ebd68a0027e9c639a390", size = 350693, upload-time = "2025-06-10T00:45:18.399Z" }, + { url = "https://files.pythonhosted.org/packages/a6/00/d393e82dd955ad20617abc546a8f1aee40534d599ff555ea053d0ec9bf03/yarl-1.20.1-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:344d1103e9c1523f32a5ed704d576172d2cabed3122ea90b1d4e11fe17c66458", size = 355803, upload-time = "2025-06-10T00:45:20.677Z" }, + { url = "https://files.pythonhosted.org/packages/9e/ed/c5fb04869b99b717985e244fd93029c7a8e8febdfcffa06093e32d7d44e7/yarl-1.20.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:88cab98aa4e13e1ade8c141daeedd300a4603b7132819c484841bb7af3edce9e", size = 341709, upload-time = "2025-06-10T00:45:23.221Z" }, + { url = "https://files.pythonhosted.org/packages/24/fd/725b8e73ac2a50e78a4534ac43c6addf5c1c2d65380dd48a9169cc6739a9/yarl-1.20.1-cp313-cp313t-win32.whl", hash = "sha256:b121ff6a7cbd4abc28985b6028235491941b9fe8fe226e6fdc539c977ea1739d", size = 86591, upload-time = "2025-06-10T00:45:25.793Z" }, + { url = "https://files.pythonhosted.org/packages/94/c3/b2e9f38bc3e11191981d57ea08cab2166e74ea770024a646617c9cddd9f6/yarl-1.20.1-cp313-cp313t-win_amd64.whl", hash = "sha256:541d050a355bbbc27e55d906bc91cb6fe42f96c01413dd0f4ed5a5240513874f", size = 93003, upload-time = "2025-06-10T00:45:27.752Z" }, + { url = "https://files.pythonhosted.org/packages/b4/2d/2345fce04cfd4bee161bf1e7d9cdc702e3e16109021035dbb24db654a622/yarl-1.20.1-py3-none-any.whl", hash = "sha256:83b8eb083fe4683c6115795d9fc1cfaf2cbbefb19b3a1cb68f6527460f483a77", size = 46542, upload-time = "2025-06-10T00:46:07.521Z" }, +]