# Simple Guidance Mechanisms for Discrete Diffusion Models
[arXiv](https://arxiv.org/abs/2412.10193)
[Project page](https://discrete-diffusion-guidance.github.io/)
[HuggingFace models](https://huggingface.co/collections/kuleshov-group/udlm-675e63ab42bc757093099e1b)
<p align="center">
<img src="https://discrete-diffusion-guidance.github.io/static/images/udlm.gif" alt="graphical abstract" width="450"/>
</p>
This repository contains code for reproducing experiments in the paper [Simple Guidance Mechanisms for Discrete Diffusion Models](https://arxiv.org/abs/2412.10193).
We also share [trained models](https://huggingface.co/collections/kuleshov-group/udlm-675e63ab42bc757093099e1b) on HuggingFace 🤗 and support integration with these models.
See the ["Using HuggingFace Models" section](#using-huggingface-models) below.
## Code Organization
<a name="code-organization"></a>
1. ```main.py```: Routines for training (language models and classifiers)
2. ```noise_schedule.py```: Noise schedules
3. ```diffusion.py```: Forward/reverse diffusion
    - Absorbing state / uniform noise diffusion
    - AR
4. ```dataloader.py```: Dataloaders
    - For the Discretized CIFAR10 and Species10 datasets, we use custom dataset classes defined in ```custom_datasets/```
5. ```utils.py```: LR scheduler, logging, `fsspec` handling
6. ```models/```: Denoising network architectures.
7. ```configs/```: Config files for datasets/denoising networks/noise schedules/LR schedules
8. ```scripts/```: Shell scripts for training/evaluation
9. ```guidance_eval/```: Guidance evaluation scripts
### Implemented Decoding Mechanisms
<a name="implemented-decoding"></a>
In [`diffusion.py`](./diffusion.py),
we define baseline and proposed decoding mechanisms for guidance.
These decoding schemes can be controlled via the hydra config with the `guidance` field.
For example, to use the proposed D-CFG guidance mechanism,
set `guidance=cfg` in the config file and optionally set the `guidance.gamma` parameter to control the strength of the guidance signal.
The implemented decoding methods are as follows:
- AR (Baseline):
    - Standard decoding (i.e., no guidance); set `guidance=null`
    - Classifier-free guidance (D-CFG); set `guidance=cfg`
    - Classifier-based guidance using [FUDGE](https://arxiv.org/abs/2104.05218) (set `guidance=fudge`) and using [PPLM](https://arxiv.org/abs/1912.02164) (set `guidance=pplm`)
- Diffusion:
    - Standard decoding (i.e., no guidance); set `guidance=null`
    - Classifier-free guidance (D-CFG); set `guidance=cfg`
    - Classifier-based guidance (D-CBG); set `guidance=cbg`
    - Classifier-based (baseline) method of [NOS](https://arxiv.org/abs/2305.20009); set `guidance=nos`
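To give a rough sense of what `guidance.gamma` controls, classifier-free guidance is commonly written as tempering a conditional distribution against an unconditional one, i.e., sampling from a distribution proportional to `p(x|y)^gamma * p(x)^(1 - gamma)`. The sketch below applies that rule to a single token position in plain Python. The function name `cfg_combine` and the toy probabilities are hypothetical and are not taken from `diffusion.py`; refer to the paper and the code for the exact formulation used here.

```python
import math

def cfg_combine(cond_probs, uncond_probs, gamma):
    """Return the normalized distribution proportional to
    p(x|y)^gamma * p(x)^(1 - gamma). Assumes all probabilities are > 0."""
    logits = [
        gamma * math.log(p_cond) + (1.0 - gamma) * math.log(p_uncond)
        for p_cond, p_uncond in zip(cond_probs, uncond_probs)
    ]
    # Subtract the max before exponentiating for numerical stability.
    max_logit = max(logits)
    weights = [math.exp(logit - max_logit) for logit in logits]
    total = sum(weights)
    return [w / total for w in weights]

# Toy distributions over a 3-token vocabulary (illustrative values only).
cond = [0.7, 0.2, 0.1]
uncond = [0.4, 0.4, 0.2]
for gamma in (0.0, 1.0, 2.0):
    print(gamma, [round(p, 3) for p in cfg_combine(cond, uncond, gamma)])
```

Under this rule, `gamma=0` recovers the unconditional distribution, `gamma=1` recovers the conditional one, and `gamma>1` pushes probability mass further toward tokens favored by the condition.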
### Implemented Generative Models
<a name="implemented-models"></a>
The three modeling parameterizations
we explore in this work are:
1. Autoregressive (AR) Models
2. Masked Diffusion Language Models (MDLM)
3. Uniform Diffusion Language Models (UDLM)
The `config` files can be used
to specify which of these parameterizations to use.
Below we detail which config parameters correspond to which model.
**AR**
```bash
diffusion="absorbing_state" # AR models can be thought of as a special case of absorbing state diffusion models
parameterization="ar"
T=0 # N/A for AR models, this is a placeholder
time_conditioning=False # AR models are not conditioned on time
zero_recon_loss=False # N/A for this model
```
**MDLM**
```bash
diffusion="absorbing_state"
parameterization="subs" # See MDLM paper for details: https://arxiv.org/abs/2406.07524
T=0 # Indicates continuous-time, i.e., T --> infinity
time_conditioning=False # MDLM not conditioned on time
zero_recon_loss=False # N/A for this model
```
**UDLM**
```bash
diffusion="uniform"
parameterization="d3pm" # Indicates that we explicitly compute KL on posteriors
T=0 # Indicates continuous-time, i.e., T --> infinity
time_conditioning=True # UDLM is conditioned on time
zero_recon_loss=True # In continuous time, recon loss evaluates to zero
```
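For scripting, the three settings above can be collected into one lookup table and rendered as `key=value` overrides. This is only a convenience sketch: `PRESETS` and `to_overrides` are hypothetical names, and the keys are written as top-level entries exactly as listed above, which may differ from how they are nested in `configs/`.

```python
# Values copied from the AR / MDLM / UDLM listings above.
PRESETS = {
    "ar": {
        "diffusion": "absorbing_state",
        "parameterization": "ar",
        "T": 0,
        "time_conditioning": False,
        "zero_recon_loss": False,
    },
    "mdlm": {
        "diffusion": "absorbing_state",
        "parameterization": "subs",
        "T": 0,
        "time_conditioning": False,
        "zero_recon_loss": False,
    },
    "udlm": {
        "diffusion": "uniform",
        "parameterization": "d3pm",
        "T": 0,
        "time_conditioning": True,
        "zero_recon_loss": True,
    },
}

def to_overrides(model_name):
    """Render one preset as a list of key=value override strings."""
    preset = PRESETS[model_name]
    return [f"{key}={value}" for key, value in preset.items()]

print(" ".join(to_overrides("udlm")))
```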
## Getting started in this repository
<a name="getting-started"></a>
To get started, create a conda environment containing the required dependencies.
```bash
conda env create -f requirements.yaml
conda activate discdiff
```
Create the following directories to store saved models and slurm logs:
```bash
mkdir outputs
mkdir watch_folder
```
We rely on `wandb` integration
to log experiments and eval curves.
## Reproducing Experiments
<a name="reproducing-experiments"></a>
Below, we describe the steps required for reproducing the experiments in the paper.
Throughout, the main entry point for running experiments is the [`main.py`](./main.py) script.
We also provide sample `slurm` scripts for launching pre-training and evaluation experiments in the [`scripts/`](./scripts) directory.
### Language Modeling Experiments
<a name="lm_training"></a>
To reproduce the language modeling results, please refer to the following shell scripts in the [`scripts/`](./scripts) directory:
- Species10: [`train_ten_species_guidance.sh`](./scripts/train_ten_species_guidance.sh)
- QM9: [`train_qm9_no-guidance.sh`](./scripts/train_qm9_no-guidance.sh)
- CIFAR10: [`train_cifar10_unet_guidance.sh`](./scripts/train_cifar10_unet_guidance.sh)
- text8: [`train_text8.sh`](./scripts/train_text8.sh)
- Amazon Polarity: [`train_amazon_polarity.sh`](./scripts/train_amazon_polarity.sh)
- LM1B: [`train_lm1b.sh`](./scripts/train_lm1b.sh)
Each script contains a comment detailing the usage.
For example, to train either an AR,
MDLM, or UDLM model on the `text8` dataset, use the following command:
```bash
cd scripts/
MODEL=<ar|mdlm|udlm>
sbatch \
--export=ALL,MODEL=${MODEL} \
--job-name=train_text8_${MODEL} \
train_text8.sh
```
### Guidance Training
<a name="guidance-training"></a>
#### Classifier-Free
<a name="guidance-training-cfg"></a>
Classifier-free guidance requires training models
that can condition on the class label
to model conditional distributions.
During training, we randomly mask out this label,
replacing it with a dummy value of `num_classes + 1`, so that the same network also simulates an unconditional model.
Refer to the shell scripts with the `_guidance` suffix
to train these models for CIFAR10,
QM9, and Species10 datasets.
For QM9, we have two experiments,
one where we condition on the drug-likeness
(`qed`)
of the molecules and another
where we condition on the ring counts (`ring_count`).
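The label masking described in this section can be sketched as follows. This is a minimal illustration, not the training code in this repository: `drop_labels` is a hypothetical helper, and the masking probability of `0.1` is an assumed value chosen only for the demo.

```python
import random

def drop_labels(labels, mask_index, drop_prob, rng):
    """Replace each label with mask_index with probability drop_prob."""
    return [mask_index if rng.random() < drop_prob else label for label in labels]

rng = random.Random(0)
num_classes = 10
mask_index = num_classes + 1  # dummy "no label" value, as described in the text
labels = [rng.randrange(num_classes) for _ in range(1000)]

# drop_prob=0.1 is an assumption for illustration, not a value from the paper.
dropped = drop_labels(labels, mask_index, drop_prob=0.1, rng=rng)
frac_masked = sum(label == mask_index for label in dropped) / len(dropped)
print(f"fraction masked: {frac_masked:.3f}")
```

Examples whose label was replaced train the network as an unconditional model, while the remaining examples train it as a conditional one.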
#### Classifier-Based
<a name="guidance-training-cbg"></a>
For classifier-based guidance,
we need to train a classifier on the noisy latent samples.
Refer to the following shell scripts
to train these classifiers:
- [FUDGE](https://arxiv.org/abs/2104.05218) (AR guidance): [`train_qm9_fudge_classifier.sh`](./scripts/train_qm9_fudge_classifier.sh)
- D-CBG (diffusion guidance): [`train_qm9_classifier.sh`](./scripts/train_qm9_classifier.sh)
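To illustrate what "noisy latent samples" means for a diffusion classifier, the sketch below corrupts a clean token sequence under the two noise types listed in this README and pairs the result with its label. It is a simplified illustration under stated assumptions: `corrupt` is a hypothetical helper, the schedule `alpha_t = 1 - t` is assumed, and the mask token is assumed to take one extra index after the vocabulary. The schedules and data pipeline actually used live in `noise_schedule.py` and `diffusion.py`.

```python
import random

def corrupt(tokens, alpha_t, kind, vocab_size, mask_id, rng):
    """Keep each token with probability alpha_t; otherwise replace it with noise."""
    noisy = []
    for token in tokens:
        if rng.random() < alpha_t:
            noisy.append(token)  # token survives unchanged
        elif kind == "absorbing_state":
            noisy.append(mask_id)  # token is absorbed into the mask state
        elif kind == "uniform":
            # Resample uniformly over the vocabulary (may equal the original).
            noisy.append(rng.randrange(vocab_size))
        else:
            raise ValueError(f"unknown diffusion kind: {kind}")
    return noisy

rng = random.Random(0)
vocab_size = 8
mask_id = vocab_size  # assumption: mask token uses one extra index
clean = [3, 1, 4, 1, 5, 2, 6, 5]
label = 1  # hypothetical class label for this sequence

# Build (noisy input, label) pairs at a few noise levels, assuming alpha_t = 1 - t.
pairs = []
for t in (0.1, 0.5, 0.9):
    z_t = corrupt(clean, 1.0 - t, "uniform", vocab_size, mask_id, rng)
    pairs.append((z_t, label))
    print(t, z_t)
```

A classifier trained on such pairs sees inputs at every noise level, which is what allows it to provide a guidance signal at intermediate steps of the reverse process.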
##### PPLM / NOS baselines
An alternative classifier-based guidance mechanism to D-CBG is that of [PPLM](https://arxiv.org/abs/1912.02164)
(which was adapted for diffusion models in [NOS](https://arxiv.org/abs/2305.20009)).
To train these classifiers,
refer to the following shell script:
[`train_qm9_pplm_classifier.sh`](./scripts/train_qm9_pplm_classifier.sh)
(for both PPLM and NOS classifiers).
### Guidance Evaluation
<a name="guidance-eval"></a>
To evaluate guidance mechanisms, we load trained models
(and classifiers, if applicable)
and generate some number of samples
for which we compute "quality" metrics
(e.g., validity/novelty in the QM9 experiments)
and control label satisfaction (e.g., mean value of novel generated molecules for the property of interest in the QM9 experiments).
The scripts for these evaluations can be found in the [`guidance_eval/`](./guidance_eval) directory.
To run these evaluations, please refer to the following shell scripts:
- QM9: [`eval_qm9_guidance.sh`](./guidance_eval/eval_qm9_guidance.sh)
- Species10: [`eval_ten_species_guidance.sh`](./guidance_eval/eval_ten_species_guidance.sh)
    - For this dataset, we also evaluate the accuracy of a HyenaDNA classifier on correctly classifying generated sequences.
      This model can be trained using [`train_ten_species_eval_classifier.sh`](./scripts/train_ten_species_eval_classifier.sh).
    - To see how this trained evaluation classifier performs on the validation set of the original data use this notebook [`eval_hyenadna_classifier.ipynb`](./notebooks/eval_hyenadna_classifier.ipynb).
In the paper,
we performed an extensive hyperparameter sweep for our proposed guidance mechanisms and for baselines.
The shell scripts can be used
to reproduce these experiments,
e.g., for the D-CFG experiments on QM9:
```bash
export MODEL=<ar|mdlm|udlm>
export PROP=<qed|ring_count>
export GUIDANCE=cfg
for GAMMA in $(seq 1 5); do
    sbatch \
      --export=ALL,MODEL=${MODEL},PROP=${PROP},GUIDANCE=${GUIDANCE},GAMMA=${GAMMA} \
      --job-name=eval_qm9_${GUIDANCE}_${PROP}_${MODEL}_GAMMA-${GAMMA} \
      eval_qm9_guidance.sh
done
```
Once each evaluation run is complete,
a `.csv` file
containing the results is saved in the run directory of the trained generative model.
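After a sweep, it can be convenient to gather these `.csv` files into one place. The sketch below uses only the Python standard library and makes no assumption about file names or column names beyond the `.csv` extension. `collect_results` is a hypothetical helper, and the demo columns are toy values, not the schema written by the evaluation scripts.

```python
import csv
import tempfile
from pathlib import Path

def collect_results(run_dir):
    """Read every .csv file found under run_dir into a list of row dicts."""
    rows = []
    for csv_path in sorted(Path(run_dir).rglob("*.csv")):
        with csv_path.open(newline="") as handle:
            for row in csv.DictReader(handle):
                row["source_file"] = csv_path.name  # remember where the row came from
                rows.append(row)
    return rows

# Demo on a throwaway directory; the column names are toy values, not the real schema.
with tempfile.TemporaryDirectory() as tmp_dir:
    (Path(tmp_dir) / "toy_results.csv").write_text("gamma,score\n1,0.5\n2,0.7\n")
    demo_rows = collect_results(tmp_dir)
    print(len(demo_rows), "rows read")
```

In practice, `run_dir` would point at the run directory of the trained generative model.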
## Using HuggingFace Models
<a name="hf_models"></a>
We provide pre-trained models on HuggingFace 🤗:
- UDLM trained on LM1B: [kuleshov-group/udlm-lm1b](https://huggingface.co/kuleshov-group/udlm-lm1b)
- UDLM trained on QM9: [kuleshov-group/udlm-qm9](https://huggingface.co/kuleshov-group/udlm-qm9)
    - Note: this model was trained without guidance and can be used with classifier-free guidance.
Please see the README pages for these models on HuggingFace or our paper for more details about the training of these models.
To use these models, you can load them using the HuggingFace API, e.g.,
```python
from transformers import AutoModelForMaskedLM
model = AutoModelForMaskedLM.from_pretrained("kuleshov-group/udlm-lm1b")
```
To use these models in our repository, set the following `config` parameters:
```bash
backbone="hf_dit"
model="hf"
model.pretrained_model_name_or_path="kuleshov-group/udlm-lm1b" # or "kuleshov-group/udlm-qm9"
```
## Acknowledgements
<a name="acknowledgements"></a>
This repository was built off of [MDLM](https://github.com/kuleshov-group/mdlm),
which in turn used [SEDD](https://github.com/louaaron/Score-Entropy-Discrete-Diffusion).
Our code implementation of D-CBG is adapted from Nisonoff et al.'s [repo](https://github.com/hnisonoff/discrete_guidance).
## Citation
<a name="citation"></a>
```
@article{
schiff2024discreteguidance,
title={Simple Guidance Mechanisms for Discrete Diffusion Models},
author={Schiff, Yair and Sahoo, Subham Sekhar and Phung, Hao and Wang, Guanghan and Boshar, Sam and Dalla-torre, Hugo and de Almeida, Bernardo P and Rush, Alexander and Pierrot, Thomas and Kuleshov, Volodymyr},
journal={arXiv preprint arXiv:2412.10193},
year={2024}
}
```