# Spaces: Sleeping
import argparse | |
import json | |
from functools import partial | |
from accelerate.state import PartialState | |
from datasets import load_dataset, load_from_disk | |
from huggingface_hub import HfApi | |
from transformers import ( | |
AutoConfig, | |
AutoImageProcessor, | |
AutoModelForObjectDetection, | |
EarlyStoppingCallback, | |
Trainer, | |
TrainingArguments, | |
) | |
from transformers.trainer_callback import PrinterCallback | |
from autotrain import logger | |
from autotrain.trainers.common import ( | |
ALLOW_REMOTE_CODE, | |
LossLoggingCallback, | |
TrainStartCallback, | |
UploadLogs, | |
monitor, | |
pause_space, | |
remove_autotrain_data, | |
save_training_params, | |
) | |
from autotrain.trainers.object_detection import utils | |
from autotrain.trainers.object_detection.params import ObjectDetectionParams | |
def parse_args():
    """Parse the command-line arguments of the training script.

    Returns:
        argparse.Namespace: namespace whose ``training_config`` attribute
        holds the string passed via the required ``--training_config`` flag.
    """
    # the end user points us at their training_config.json
    cli = argparse.ArgumentParser()
    cli.add_argument("--training_config", required=True, type=str)
    namespace = cli.parse_args()
    return namespace
def train(config):
    """Fine-tune an object-detection model as described by ``config``.

    The function loads the train (and optional validation) split, builds the
    model and image processor from ``config.model``, runs a ``Trainer``,
    saves the model, image processor and a ``README.md`` model card into
    ``config.project_name`` and, when ``config.push_to_hub`` is set, uploads
    that folder to ``{config.username}/{config.project_name}`` on the Hub.

    Args:
        config: an ``ObjectDetectionParams`` instance, or a ``dict`` of
            keyword arguments from which one is constructed.

    Returns:
        None.

    NOTE(review): ``monitor`` is imported by this module but is not applied
    to this function in the visible source -- confirm whether a ``@monitor``
    decorator was intended.
    """
    if isinstance(config, dict):
        config = ObjectDetectionParams(**config)

    train_data = _load_split(config, config.train_split)
    valid_data = None
    has_validation = config.valid_split is not None
    if has_validation:
        valid_data = _load_split(config, config.valid_split)

    logger.info(f"Train data: {train_data}")
    logger.info(f"Valid data: {valid_data}")

    # Class names come from the "category" feature of the objects column.
    categories = train_data.features[config.objects_column].feature["category"].names
    id2label = dict(enumerate(categories))
    label2id = {v: k for k, v in id2label.items()}

    model_config = AutoConfig.from_pretrained(
        config.model,
        label2id=label2id,
        id2label=id2label,
        trust_remote_code=ALLOW_REMOTE_CODE,
        token=config.token,
    )

    model_kwargs = dict(
        config=model_config,
        ignore_mismatched_sizes=True,
        trust_remote_code=ALLOW_REMOTE_CODE,
        token=config.token,
    )
    try:
        model = AutoModelForObjectDetection.from_pretrained(config.model, **model_kwargs)
    except OSError:
        # Retry, loading the checkpoint from TensorFlow weights instead.
        model = AutoModelForObjectDetection.from_pretrained(config.model, from_tf=True, **model_kwargs)

    image_processor = AutoImageProcessor.from_pretrained(
        config.model,
        token=config.token,
        do_pad=False,
        do_resize=False,
        size={"longest_edge": config.image_square_size},
        trust_remote_code=ALLOW_REMOTE_CODE,
    )
    train_data, valid_data = utils.process_data(train_data, valid_data, image_processor, config)

    if config.logging_steps == -1:
        # Auto mode: 20% of the number of batches in the validation split
        # (or the train split when there is none), clamped to [1, 25].
        reference_data = valid_data if has_validation else train_data
        logging_steps = int(0.2 * len(reference_data) / config.batch_size)
        logging_steps = min(max(logging_steps, 1), 25)
        # The resolved value is written back onto the config object.
        config.logging_steps = logging_steps
    else:
        logging_steps = config.logging_steps

    logger.info(f"Logging steps: {logging_steps}")

    # Without a validation split there is nothing to evaluate, so evaluation,
    # checkpoint saving and best-model reloading are all turned off.
    training_args = dict(
        output_dir=config.project_name,
        per_device_train_batch_size=config.batch_size,
        per_device_eval_batch_size=2 * config.batch_size,
        learning_rate=config.lr,
        num_train_epochs=config.epochs,
        eval_strategy=config.eval_strategy if has_validation else "no",
        logging_steps=logging_steps,
        save_total_limit=config.save_total_limit,
        save_strategy=config.eval_strategy if has_validation else "no",
        gradient_accumulation_steps=config.gradient_accumulation,
        report_to=config.log,
        auto_find_batch_size=config.auto_find_batch_size,
        lr_scheduler_type=config.scheduler,
        optim=config.optimizer,
        warmup_ratio=config.warmup_ratio,
        weight_decay=config.weight_decay,
        max_grad_norm=config.max_grad_norm,
        push_to_hub=False,
        load_best_model_at_end=has_validation,
        ddp_find_unused_parameters=False,
    )

    if config.mixed_precision == "fp16":
        training_args["fp16"] = True
    if config.mixed_precision == "bf16":
        training_args["bf16"] = True

    callbacks_to_use = []
    if has_validation:
        training_args["eval_do_concat_batches"] = False
        callbacks_to_use.append(
            EarlyStoppingCallback(
                early_stopping_patience=config.early_stopping_patience,
                early_stopping_threshold=config.early_stopping_threshold,
            )
        )
    callbacks_to_use.extend([UploadLogs(config=config), LossLoggingCallback(), TrainStartCallback()])

    _compute_metrics_fn = partial(
        utils.object_detection_metrics, image_processor=image_processor, id2label=id2label, threshold=0.0
    )

    args = TrainingArguments(**training_args)
    trainer = Trainer(
        args=args,
        model=model,
        callbacks=callbacks_to_use,
        data_collator=utils.collate_fn,
        tokenizer=image_processor,
        compute_metrics=_compute_metrics_fn,
        train_dataset=train_data,
        eval_dataset=valid_data,
    )
    trainer.remove_callback(PrinterCallback)
    trainer.train()

    logger.info("Finished training, saving model...")
    trainer.save_model(config.project_name)
    image_processor.save_pretrained(config.project_name)

    model_card = utils.create_model_card(config, trainer)

    # Save the model card to the output directory as README.md. The encoding
    # is explicit so non-ASCII card content does not depend on the locale.
    with open(f"{config.project_name}/README.md", "w", encoding="utf-8") as f:
        f.write(model_card)

    if config.push_to_hub:
        # Only process index 0 cleans up and uploads.
        if PartialState().process_index == 0:
            remove_autotrain_data(config)
            save_training_params(config)
            logger.info("Pushing model to hub...")
            repo_id = f"{config.username}/{config.project_name}"
            api = HfApi(token=config.token)
            api.create_repo(repo_id=repo_id, repo_type="model", private=True, exist_ok=True)
            api.upload_folder(folder_path=config.project_name, repo_id=repo_id, repo_type="model")

    if PartialState().process_index == 0:
        pause_space(config)


def _load_split(config, split_spec):
    """Load one dataset split, from the local AutoTrain dump or via ``load_dataset``.

    Args:
        config: training parameters; ``data_path``, ``project_name`` and
            ``token`` are read.
        split_spec: split name, optionally written as ``"<config>:<split>"``
            to select a dataset configuration. Only the first ``":"`` is
            treated as the separator, so the split part may itself contain
            colons.

    Returns:
        The loaded split.
    """
    if config.data_path == f"{config.project_name}/autotrain-data":
        return load_from_disk(config.data_path)[split_spec]

    load_kwargs = dict(token=config.token, trust_remote_code=ALLOW_REMOTE_CODE)
    if ":" in split_spec:
        dataset_config_name, split = split_spec.split(":", 1)
        return load_dataset(config.data_path, name=dataset_config_name, split=split, **load_kwargs)
    return load_dataset(config.data_path, split=split_spec, **load_kwargs)
if __name__ == "__main__":
    # Script entry point: read the JSON file named by --training_config,
    # build the parameter object from it and run training.
    _args = parse_args()
    # A context manager closes the file handle deterministically; JSON is
    # read as UTF-8 rather than with the locale's default encoding.
    with open(_args.training_config, encoding="utf-8") as _config_file:
        training_config = json.load(_config_file)
    _config = ObjectDetectionParams(**training_config)
    train(_config)