# Download Model Checkpoints
This document describes how to use the provided Python script to download model checkpoints from HuggingFace or GitHub, based on a YAML configuration file.
## Overview
The script allows users to download machine learning model checkpoints from platforms like HuggingFace or GitHub. It supports downloading either full model checkpoints or specific files for inference, with options to filter based on whether the model is a base model or intended for inference. The configuration is provided via a YAML file, and the script uses command-line arguments for additional flexibility.
## Requirements
- Python 3.x
- Required Python packages:
- `argparse` (standard library)
- `yaml` (install via `pip install pyyaml`)
- `os` (standard library)
- `requests` (install via `pip install requests`)
- `huggingface_hub` (install via `pip install huggingface_hub`)
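The non-standard packages can be installed in one step (`argparse` and `os` ship with Python):
```bash
pip install pyyaml requests huggingface_hub
```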
## Usage
Run the script from the command line with the following options:
```bash
python script_name.py [--config CONFIG_PATH] [--full_ckpts] [--include_base_model] [--base_model_only]
```
### Command-Line Arguments
| Argument | Description | Default Value |
|------------------------|-----------------------------------------------------------------------------|-----------------------------------|
| `--config` | Path to the YAML configuration file. | `configs/model_ckpts.yaml` |
| `--full_ckpts` | If specified, downloads all models using `snapshot_download`. Otherwise, downloads only specific files for models marked `for_inference` in the YAML. | Not set (False) |
| `--include_base_model` | If specified, downloads all models (both `base_model: true` and `false`). Otherwise, skips models with `base_model: true`. | Not set (False) |
| `--base_model_only` | If specified, downloads only models with `base_model: true`, ignoring `for_inference`. | Not set (False) |
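For reference, a minimal sketch of how these flags could be declared with `argparse`; the flag names and defaults mirror the table above, but the actual script may structure its parser differently:
```python
import argparse

def parse_args():
    # Boolean flags use store_true, so they default to False when omitted.
    parser = argparse.ArgumentParser(description="Download model checkpoints")
    parser.add_argument("--config", default="configs/model_ckpts.yaml",
                        help="Path to the YAML configuration file")
    parser.add_argument("--full_ckpts", action="store_true",
                        help="Download full checkpoints via snapshot_download")
    parser.add_argument("--include_base_model", action="store_true",
                        help="Also download models marked base_model: true")
    parser.add_argument("--base_model_only", action="store_true",
                        help="Download only models marked base_model: true")
    return parser.parse_args()
```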
### YAML Configuration
The script expects a YAML configuration file (default: `configs/model_ckpts.yaml`) with a list of model configurations. Each model configuration should include:
| Key | Description | Required | Example Value |
|----------------|-----------------------------------------------------------------------------|----------|-----------------------------------|
| `model_id` | Identifier for the model (e.g., HuggingFace repository ID or model name). | Yes | `bert-base-uncased` |
| `local_dir` | Local directory to save the downloaded model. | Yes | `./models/bert` |
| `platform` | Platform to download from (`HuggingFace` or `GitHub`). | Yes | `HuggingFace` |
| `url` | Direct URL to the model file (required for GitHub platform). | No | `https://github.com/.../model.pth` |
| `filename`     | Specific file to download. Required for HuggingFace when `--full_ckpts` is not set; for GitHub it defaults to a name derived from `url`. | No | `pytorch_model.bin` |
| `base_model` | Boolean indicating if the model is a base model. | No | `true` or `false` |
| `for_inference`| Boolean indicating if the model is used for inference. | No | `true` or `false` |
#### Example YAML Configuration
```yaml
- model_id: bert-base-uncased
local_dir: ./models/bert
platform: HuggingFace
filename: pytorch_model.bin
for_inference: true
base_model: false
- model_id: custom-model
local_dir: ./models/custom
platform: GitHub
url: https://github.com/user/repo/releases/download/v1.0/model.pth
filename: model.pth
for_inference: false
base_model: true
```
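A minimal sketch of the `load_config` step, assuming the YAML top level is a list of model dictionaries as in the example above:
```python
import yaml

def load_config(config_path):
    # Parse the YAML file into a list of per-model dictionaries.
    with open(config_path, "r") as f:
        models = yaml.safe_load(f)
    if not isinstance(models, list):
        raise ValueError(f"Expected a list of model configs in {config_path}")
    return models
```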
## Script Workflow
1. **Parse Command-Line Arguments**:
- The script uses `argparse` to handle command-line arguments for configuration file path and download options.
2. **Load YAML Configuration**:
- Reads the specified YAML file to load model configurations using the `load_config` function.
3. **Filter Models**:
- Based on the command-line arguments, the script filters models to download:
- If `--base_model_only` is set, only models with `base_model: true` are downloaded.
- If `--full_ckpts` is not set, only models with `for_inference: true` are downloaded.
- If `--include_base_model` is not set, models with `base_model: true` are skipped unless explicitly included via `--base_model_only`.
4. **Download Models**:
- For each model in the configuration that passes the filters, the `download_model` function is called (see the sketch after this list).
- **HuggingFace**:
- If `--full_ckpts` is set, uses `snapshot_download` to download the model repository, restricted to files with `.pth`, `.bin`, or `.json` extensions.
- Otherwise, uses `hf_hub_download` to download a specific file specified by `filename`.
- **GitHub**:
- Downloads the file from the provided `url` using `requests` and saves it to `local_dir` with the specified `filename` (or derived from the URL if not provided).
5. **Error Handling**:
- Ensures the local directory exists before downloading.
- Validates required fields (`url` for GitHub, `filename` for non-full HuggingFace downloads).
- Raises errors for unsupported platforms or failed downloads (e.g., HTTP errors).
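The filtering and download steps above might look roughly like the sketch below. This is an illustrative reconstruction under the behavior documented in this README, not the script itself: the helper name `should_download`, the exact filter precedence, and the use of `allow_patterns` to restrict `snapshot_download` to `.pth`/`.bin`/`.json` files are assumptions.
```python
import os
import requests
from huggingface_hub import hf_hub_download, snapshot_download

def should_download(model, args):
    # Step 3: apply the command-line filters to one YAML entry (assumed logic).
    is_base = model.get("base_model", False)
    if args.base_model_only:
        return is_base                        # only base models, ignore for_inference
    if is_base and not args.include_base_model:
        return False                          # skip base models unless included
    if not args.full_ckpts:
        return model.get("for_inference", False)
    return True

def download_model(model, full_ckpts):
    # Step 4: fetch the checkpoint from HuggingFace or GitHub.
    local_dir = model["local_dir"]
    os.makedirs(local_dir, exist_ok=True)     # step 5: ensure the directory exists

    if model["platform"] == "HuggingFace":
        if full_ckpts:
            # Whole repository, restricted to checkpoint/config file types.
            snapshot_download(repo_id=model["model_id"], local_dir=local_dir,
                              allow_patterns=["*.pth", "*.bin", "*.json"])
        else:
            if "filename" not in model:
                raise ValueError(f"'filename' is required for {model['model_id']}")
            hf_hub_download(repo_id=model["model_id"],
                            filename=model["filename"], local_dir=local_dir)
    elif model["platform"] == "GitHub":
        url = model.get("url")
        if not url:
            raise ValueError(f"'url' is required for {model['model_id']}")
        filename = model.get("filename") or url.rsplit("/", 1)[-1]
        response = requests.get(url, stream=True)
        response.raise_for_status()           # step 5: surface HTTP errors
        with open(os.path.join(local_dir, filename), "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
    else:
        raise ValueError(f"Unsupported platform: {model['platform']}")
```
The main loop would then iterate over `load_config(args.config)` and call `download_model` for each entry that passes `should_download`.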
## Example Commands
1. **Download specific inference files** (default behavior):
```bash
python script_name.py --config configs/model_ckpts.yaml
```
Downloads only models with `for_inference: true` and `base_model: false`.
2. **Download full model checkpoints**:
```bash
python script_name.py --config configs/model_ckpts.yaml --full_ckpts
```
Downloads full checkpoints for all models with `base_model: false`, regardless of `for_inference`.
3. **Include base models**:
```bash
python script_name.py --config configs/model_ckpts.yaml --include_base_model
```
Downloads models regardless of `base_model` status, but still respects `for_inference` unless `--full_ckpts` is set.
4. **Download only base models**:
```bash
python script_name.py --config configs/model_ckpts.yaml --base_model_only
```
Downloads only models with `base_model: true`, ignoring `for_inference`.
## Notes
- Ensure the YAML configuration file is correctly formatted and accessible.
- An internet connection and sufficient disk space are required for both HuggingFace and GitHub downloads.
- For GitHub downloads, the provided URL must be a direct link to a downloadable file.
- The script creates the `local_dir` if it does not exist.
- Downloaded files are saved to the specified `local_dir` with their original filenames (or as specified in the YAML). |