feat: cleanup training script

dev/seq2seq/run_seq2seq_flax.py: CHANGED (+26, -65)
@@ -23,21 +23,19 @@ import os
 import logging as pylogging  # To avoid collision with transformers.utils.logging
 import sys
 from dataclasses import dataclass, field
-from functools import partial
 from pathlib import Path
 from typing import Callable, Optional
 import json

 import datasets
 import numpy as np
-from datasets import Dataset, load_dataset
+from datasets import Dataset, load_dataset
 from tqdm import tqdm

 import jax
 import jax.numpy as jnp
 import optax
 import transformers
-from filelock import FileLock
 from flax import jax_utils, traverse_util
 from flax.serialization import from_bytes, to_bytes
 import flax.linen as nn
@@ -45,15 +43,12 @@ from flax.jax_utils import unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
 from transformers import (
-    FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
     AutoTokenizer,
-    FlaxAutoModelForSeq2SeqLM,
     FlaxBartForConditionalGeneration,
     HfArgumentParser,
     TrainingArguments,
 )
 from transformers.models.bart.modeling_flax_bart import *
-from transformers.file_utils import is_offline_mode

 import wandb

@@ -62,10 +57,6 @@ from dalle_mini.text import TextNormalizer
 logger = pylogging.getLogger(__name__)


-MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys())
-MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
-
-
 # Model hyperparameters, for convenience
 # TODO: the model has now it's own definition file and should be imported
 OUTPUT_VOCAB_SIZE = 16384 + 1  # encoded image token space + 1 for bos
@@ -87,25 +78,12 @@ class ModelArguments:
             "Don't set if you want to train a model from scratch."
         },
     )
-    model_type: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "If training from scratch, pass a model type from the list: "
-            + ", ".join(MODEL_TYPES)
-        },
-    )
     config_name: Optional[str] = field(
         default=None,
         metadata={
             "help": "Pretrained config name or path if not the same as model_name"
         },
     )
-    cache_dir: Optional[str] = field(
-        default=None,
-        metadata={
-            "help": "Where do you want to store the pretrained models downloaded from s3"
-        },
-    )
     use_fast_tokenizer: bool = field(
         default=True,
         metadata={
@@ -281,6 +259,19 @@ class TrainState(train_state.TrainState):
             dropout_rng=shard_prng_key(self.dropout_rng)
         )

+    def restore_state(self, artifact_dir):
+        # restore optimizer state
+        with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
+            opt_state = from_bytes(self.opt_state, f.read())
+
+        # restore steps
+        with (Path(artifact_dir) / "training_state.json").open("r") as f:
+            training_state = json.load(f)
+        step = training_state["step"]
+
+        # replace state
+        return self.replace(step=step, opt_state=opt_state)
+

 class CustomFlaxBartModule(FlaxBartModule):
     def setup(self):
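Note: the new TrainState.restore_state method reads two files from a checkpoint directory, the serialized optimizer state ("opt_state.msgpack") and a small JSON file with the step counter. The matching save side is not part of this hunk; below is a hedged sketch of what it would look like, assuming a plain directory layout. The helper name save_training_state is hypothetical.

# Sketch only: writes the two files that TrainState.restore_state reads back.
# The helper name and the plain-directory layout are assumptions; the actual
# script handles this as part of its wandb checkpoint logic.
import json
from pathlib import Path

from flax.serialization import to_bytes


def save_training_state(state, artifact_dir):
    artifact_dir = Path(artifact_dir)
    # optimizer state -> msgpack blob (counterpart of from_bytes in restore_state)
    with (artifact_dir / "opt_state.msgpack").open("wb") as f:
        f.write(to_bytes(state.opt_state))
    # step counter -> small JSON file read back as training_state["step"]
    with (artifact_dir / "training_state.json").open("w") as f:
        json.dump({"step": int(state.step)}, f)

Resuming then reduces to state = state.restore_state(artifact_dir), as a later hunk in this diff shows.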
@@ -480,22 +471,6 @@ def main():
         streaming=data_args.streaming,
     )

-    # Set up items to load or create
-    tokenizer = None
-    artifact_dir = None
-
-    def restore_state(state, artifact_dir):
-        # restore optimizer state
-        with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
-            opt_state = from_bytes(state.opt_state, f.read())
-
-        # restore steps
-        with (Path(artifact_dir) / "training_state.json").open("r") as f:
-            training_state = json.load(f)
-        step = training_state["step"]
-
-        return step, opt_state
-
     # Set up wandb run
     wandb.init(
         entity="dalle-mini",
@@ -510,22 +485,11 @@ def main():
         artifact_dir = artifact.download()
         model = CustomFlaxBartForConditionalGeneration.from_pretrained(artifact_dir)

-        #
-
-
-
-
-
-        # used in the preprocessing function
-        config = model.config
-
-        # load tokenizer if present
-        if (Path(artifact_dir) / "tokenizer_config.json").exists():
-            tokenizer = AutoTokenizer.from_pretrained(
-                model_args.model_name_or_path,
-                cache_dir=model_args.cache_dir,
-                use_fast=model_args.use_fast_tokenizer,
-            )
+        # load tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(
+            artifact_dir,
+            use_fast=model_args.use_fast_tokenizer,
+        )

     else:
         # Set up our new model config
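Note: in the from_checkpoint branch the tokenizer is now loaded from the downloaded wandb artifact directory, right next to the model, instead of conditionally from model_name_or_path. A minimal sketch of that flow is below; the project name and artifact reference are placeholders, and the stock FlaxBartForConditionalGeneration stands in for the script's CustomFlaxBartForConditionalGeneration so the snippet stays self-contained.

# Hedged sketch of the checkpoint-resume path shown above.
import wandb
from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

run = wandb.init(entity="dalle-mini", project="dalle-mini")  # project name is an assumption
artifact = run.use_artifact("dalle-mini/dalle-mini/model:latest")  # hypothetical artifact reference
artifact_dir = artifact.download()

# The script loads its CustomFlaxBartForConditionalGeneration subclass here;
# the stock class is used only to keep this sketch runnable on its own.
model = FlaxBartForConditionalGeneration.from_pretrained(artifact_dir)
tokenizer = AutoTokenizer.from_pretrained(artifact_dir, use_fast=True)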
@@ -552,11 +516,9 @@ def main():
             config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
         )

-
-    if tokenizer is None:
+        # Load tokenizer
         tokenizer = AutoTokenizer.from_pretrained(
             model_args.model_name_or_path,
-            cache_dir=model_args.cache_dir,
             use_fast=model_args.use_fast_tokenizer,
         )

@@ -609,7 +571,9 @@ def main():
         model_inputs["labels"] = labels

         # In our case, this prepends the bos token and removes the last one
-        decoder_input_ids = shift_tokens_right(
+        decoder_input_ids = shift_tokens_right(
+            labels, model.config.decoder_start_token_id
+        )
         model_inputs["decoder_input_ids"] = decoder_input_ids

         return model_inputs
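Note: the comment in this hunk describes what the shift does here: prepend the decoder start (bos) token and drop the last label token, so the decoder learns to predict token t from the tokens before it. A minimal, self-contained illustration of that behaviour (not the transformers implementation, and the function name is hypothetical):

import numpy as np

def shift_right_sketch(labels: np.ndarray, decoder_start_token_id: int) -> np.ndarray:
    # prepend the start token, drop the last label token
    shifted = np.zeros_like(labels)
    shifted[:, 0] = decoder_start_token_id
    shifted[:, 1:] = labels[:, :-1]
    return shifted

labels = np.array([[11, 12, 13, 14]])
print(shift_right_sketch(labels, decoder_start_token_id=0))
# [[ 0 11 12 13]]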
@@ -787,8 +751,7 @@ def main():
     )
     if model_args.from_checkpoint is not None:
         # restore optimizer state and step
-        step, opt_state = restore_state(state, artifact_dir)
-        state = state.replace(step=step, opt_state=opt_state)
+        state = state.restore_state(artifact_dir)
     # TODO: number of remaining training epochs/steps and dataloader state need to be adjusted

     # label smoothed cross entropy
@@ -974,16 +937,14 @@ def main():
     for epoch in epochs:
         # ======================== Training ================================
         step = unreplicate(state.step)
-        wandb_log({"train/epoch": epoch}, step=step)
-
-        # Create sampling rng
-        rng, input_rng = jax.random.split(rng)
+        # wandb_log({"train/epoch": epoch}, step=step)

         # Generate an epoch by shuffling sampling indices from the train dataset
         if data_args.streaming:
             train_dataset.set_epoch(epoch)
             train_loader = data_loader_streaming(train_dataset, train_batch_size)
         else:
+            rng, input_rng = jax.random.split(rng)
             train_loader = data_loader(
                 input_rng, train_dataset, train_batch_size, shuffle=True
             )
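Note: the sampling RNG is now split only in the non-streaming branch, right where the shuffling loader consumes it. The data_loader itself is outside this diff, so the sketch below is an assumption about what such a call typically implies in Flax-style training scripts, not the script's actual implementation.

# Sketch of a shuffling data loader in the style of the Flax examples;
# names and details are assumptions.
import jax
import numpy as np
from flax.training.common_utils import shard


def data_loader_sketch(rng, dataset, batch_size, shuffle=False):
    steps = len(dataset) // batch_size  # drop the last incomplete batch
    if shuffle:
        perm = np.asarray(jax.random.permutation(rng, len(dataset)))
    else:
        perm = np.arange(len(dataset))
    for i in range(steps):
        idx = perm[i * batch_size : (i + 1) * batch_size].tolist()
        batch = dataset[idx]  # dict of lists for a Hugging Face Dataset
        # batch_size is assumed to be a multiple of the local device count
        yield shard({k: np.asarray(v) for k, v in batch.items()})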