Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| D-FINE: Redefine Regression Task of DETRs as Fine-grained Distribution Refinement | |
| Copyright (c) 2024 The D-FINE Authors. All Rights Reserved. | |
| --------------------------------------------------------------------------------- | |
| Modified from RT-DETR (https://github.com/lyuwenyu/RT-DETR) | |
| Copyright (c) 2023 lyuwenyu. All Rights Reserved. | |
| """ | |
import datetime
import json
import time

import torch

from ..misc import dist_utils, stats
from ._solver import BaseSolver
from .det_engine import evaluate, train_one_epoch
class DetSolver(BaseSolver):
    """Detection solver built on :class:`BaseSolver`.

    ``fit`` drives the full train/eval loop with two-stage checkpointing
    (``best_stg1.pth`` before ``collate_fn.stop_epoch``, ``best_stg2.pth``
    after); ``val`` runs a single evaluation pass.
    """

    def fit(self):
        """Train for ``cfg.epochs`` epochs, evaluating after every epoch.

        Side effects: writes checkpoints, ``log.txt`` and ``eval/`` artifacts
        under ``self.output_dir``; logs to TensorBoard (``self.writer``) and,
        when enabled, Weights & Biases.
        """
        self.train()  # BaseSolver hook: switch solver/model into training mode
        args = self.cfg
        # Order mirrors the COCO bbox metric vector (test_stats["coco_eval_bbox"]).
        metric_names = ["AP50:95", "AP50", "AP75", "APsmall", "APmedium", "APlarge"]

        if self.use_wandb:
            import wandb

            wandb.init(
                project=args.yaml_cfg["project_name"],
                name=args.yaml_cfg["exp_name"],
                config=args.yaml_cfg,
            )
            wandb.watch(self.model)

        n_parameters, model_stats = stats(self.cfg)
        print(model_stats)
        print("-" * 42 + "Start training" + "-" * 43)

        top1 = 0  # best primary metric (test_stats[k][0]) seen so far
        best_stat = {
            "epoch": -1,
        }
        # When resuming, evaluate once up front so best_stat/top1 reflect the
        # restored weights (EMA weights preferred when available).
        if self.last_epoch > 0:
            module = self.ema.module if self.ema else self.model
            test_stats, coco_evaluator = evaluate(
                module,
                self.criterion,
                self.postprocessor,
                self.val_dataloader,
                self.evaluator,
                self.device,
                self.last_epoch,
                self.use_wandb
            )
            for k in test_stats:
                best_stat["epoch"] = self.last_epoch
                best_stat[k] = test_stats[k][0]
                top1 = test_stats[k][0]
                print(f"best_stat: {best_stat}")

        best_stat_print = best_stat.copy()
        start_time = time.time()
        start_epoch = self.last_epoch + 1
        for epoch in range(start_epoch, args.epochs):
            self.train_dataloader.set_epoch(epoch)
            # self.train_dataloader.dataset.set_epoch(epoch)
            if dist_utils.is_dist_available_and_initialized():
                self.train_dataloader.sampler.set_epoch(epoch)

            # Stage switch: reload the stage-1 best weights and restart EMA
            # with the configured restart decay when entering stage 2.
            if epoch == self.train_dataloader.collate_fn.stop_epoch:
                self.load_resume_state(str(self.output_dir / "best_stg1.pth"))
                if self.ema:
                    self.ema.decay = self.train_dataloader.collate_fn.ema_restart_decay
                    print(f"Refresh EMA at epoch {epoch} with decay {self.ema.decay}")

            train_stats = train_one_epoch(
                self.model,
                self.criterion,
                self.train_dataloader,
                self.optimizer,
                self.device,
                epoch,
                max_norm=args.clip_max_norm,
                print_freq=args.print_freq,
                ema=self.ema,
                scaler=self.scaler,
                lr_warmup_scheduler=self.lr_warmup_scheduler,
                writer=self.writer,
                use_wandb=self.use_wandb,
                output_dir=self.output_dir,
            )

            # Step the main LR schedule only once warmup has finished.
            if self.lr_warmup_scheduler is None or self.lr_warmup_scheduler.finished():
                self.lr_scheduler.step()

            self.last_epoch += 1

            # Rolling checkpoints (last.pth + periodic) are stage-1 only.
            if self.output_dir and epoch < self.train_dataloader.collate_fn.stop_epoch:
                checkpoint_paths = [self.output_dir / "last.pth"]
                # extra checkpoint before LR drop and every 100 epochs
                if (epoch + 1) % args.checkpoint_freq == 0:
                    checkpoint_paths.append(self.output_dir / f"checkpoint{epoch:04}.pth")
                for checkpoint_path in checkpoint_paths:
                    dist_utils.save_on_master(self.state_dict(), checkpoint_path)

            module = self.ema.module if self.ema else self.model
            test_stats, coco_evaluator = evaluate(
                module,
                self.criterion,
                self.postprocessor,
                self.val_dataloader,
                self.evaluator,
                self.device,
                epoch,
                self.use_wandb,
                output_dir=self.output_dir,
            )

            # TODO
            for k in test_stats:
                if self.writer and dist_utils.is_main_process():
                    for i, v in enumerate(test_stats[k]):
                        # Fix: dropped the redundant `.format(k)` that was
                        # chained onto this already-evaluated f-string (a no-op).
                        self.writer.add_scalar(f"Test/{k}_{i}", v, epoch)

                if k in best_stat:
                    best_stat["epoch"] = (
                        epoch if test_stats[k][0] > best_stat[k] else best_stat["epoch"]
                    )
                    best_stat[k] = max(best_stat[k], test_stats[k][0])
                else:
                    best_stat["epoch"] = epoch
                    best_stat[k] = test_stats[k][0]

                # New global best: save the stage-appropriate "best" checkpoint.
                if best_stat[k] > top1:
                    best_stat_print["epoch"] = epoch
                    top1 = best_stat[k]
                    if self.output_dir:
                        if epoch >= self.train_dataloader.collate_fn.stop_epoch:
                            dist_utils.save_on_master(
                                self.state_dict(), self.output_dir / "best_stg2.pth"
                            )
                        else:
                            dist_utils.save_on_master(
                                self.state_dict(), self.output_dir / "best_stg1.pth"
                            )

                best_stat_print[k] = max(best_stat[k], top1)
                print(f"best_stat: {best_stat_print}")  # global best

                if best_stat["epoch"] == epoch and self.output_dir:
                    if epoch >= self.train_dataloader.collate_fn.stop_epoch:
                        if test_stats[k][0] > top1:
                            top1 = test_stats[k][0]
                            dist_utils.save_on_master(
                                self.state_dict(), self.output_dir / "best_stg2.pth"
                            )
                    else:
                        top1 = max(test_stats[k][0], top1)
                        dist_utils.save_on_master(
                            self.state_dict(), self.output_dir / "best_stg1.pth"
                        )

                # Stage 2 without a new best this epoch: reset the best tracker,
                # roll back to the stage-1 best and slightly lower the EMA decay.
                elif epoch >= self.train_dataloader.collate_fn.stop_epoch:
                    best_stat = {
                        "epoch": -1,
                    }
                    if self.ema:
                        self.ema.decay -= 0.0001
                        self.load_resume_state(str(self.output_dir / "best_stg1.pth"))
                        print(f"Refresh EMA at epoch {epoch} with decay {self.ema.decay}")

            log_stats = {
                **{f"train_{k}": v for k, v in train_stats.items()},
                **{f"test_{k}": v for k, v in test_stats.items()},
                "epoch": epoch,
                "n_parameters": n_parameters,
            }

            if self.use_wandb:
                wandb_logs = {}
                for idx, metric_name in enumerate(metric_names):
                    wandb_logs[f"metrics/{metric_name}"] = test_stats["coco_eval_bbox"][idx]
                wandb_logs["epoch"] = epoch
                wandb.log(wandb_logs)

            if self.output_dir and dist_utils.is_main_process():
                with (self.output_dir / "log.txt").open("a") as f:
                    f.write(json.dumps(log_stats) + "\n")

                # for evaluation logs
                if coco_evaluator is not None:
                    (self.output_dir / "eval").mkdir(exist_ok=True)
                    if "bbox" in coco_evaluator.coco_eval:
                        filenames = ["latest.pth"]
                        if epoch % 50 == 0:
                            filenames.append(f"{epoch:03}.pth")
                        for name in filenames:
                            torch.save(
                                coco_evaluator.coco_eval["bbox"].eval,
                                self.output_dir / "eval" / name,
                            )

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print("Training time {}".format(total_time_str))

    def val(self):
        """Run one evaluation pass with the EMA-preferred weights and save the
        raw COCO bbox eval results to ``output_dir / "eval.pth"``."""
        self.eval()  # BaseSolver hook: switch solver/model into eval mode
        module = self.ema.module if self.ema else self.model
        test_stats, coco_evaluator = evaluate(
            module,
            self.criterion,
            self.postprocessor,
            self.val_dataloader,
            self.evaluator,
            self.device,
            epoch=-1,
            use_wandb=False,
        )

        if self.output_dir:
            dist_utils.save_on_master(
                coco_evaluator.coco_eval["bbox"].eval, self.output_dir / "eval.pth"
            )

        return