|
"""
|
|
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
|
|
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
|
|
"""
|
|
|
|
import datetime
|
|
import json
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from ..misc import dist_utils
|
|
from ._solver import BaseSolver
|
|
from .clas_engine import evaluate, train_one_epoch
|
|
|
|
|
|
class ClasSolver(BaseSolver):
    """Solver that runs the training / evaluation loop for classification models."""

    def fit(
        self,
    ):
        """Train ``self.model`` from ``self.last_epoch + 1`` up to ``cfg.epochs``.

        Each epoch does the following:

        - Runs ``train_one_epoch``.
        - Steps the LR scheduler.
        - Writes checkpoints into ``cfg.output_dir``.
        - Evaluates on ``self.val_dataloader``, using ``self.ema.module`` when an
          EMA wrapper is set and the raw model otherwise.
        - On the main process, appends one JSON line of train/test statistics
          to ``<output_dir>/log.txt``.

        The total wall-clock training time is printed at the end.
        """
        print("Start training")
        # BaseSolver setup; NOTE(review): assumed to build model/optimizer/loaders — confirm.
        self.train()
        args = self.cfg

        n_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print("Number of params:", n_parameters)

        output_dir = Path(args.output_dir)
        # parents=True so nested output paths (e.g. "runs/exp1") can be created from scratch.
        output_dir.mkdir(parents=True, exist_ok=True)

        start_time = time.time()
        start_epoch = self.last_epoch + 1
        for epoch in range(start_epoch, args.epochs):
            if dist_utils.is_dist_available_and_initialized():
                # Presumably a DistributedSampler: set_epoch varies its shuffling per epoch.
                self.train_dataloader.sampler.set_epoch(epoch)

            train_stats = train_one_epoch(
                self.model,
                self.criterion,
                self.train_dataloader,
                self.optimizer,
                self.ema,
                epoch=epoch,
                device=self.device,
            )
            self.lr_scheduler.step()
            self.last_epoch += 1

            self._save_checkpoints(output_dir, epoch, args.checkpoint_freq)

            # Evaluate the EMA weights when EMA is enabled, otherwise the raw model.
            module = self.ema.module if self.ema else self.model
            test_stats = evaluate(module, self.criterion, self.val_dataloader, self.device)

            log_stats = {
                **{f"train_{k}": v for k, v in train_stats.items()},
                **{f"test_{k}": v for k, v in test_stats.items()},
                "epoch": epoch,
                "n_parameters": n_parameters,
            }

            # Only one process writes the log so lines are not duplicated across ranks.
            if dist_utils.is_main_process():
                with (output_dir / "log.txt").open("a") as f:
                    f.write(json.dumps(log_stats) + "\n")

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print("Training time {}".format(total_time_str))

    def _save_checkpoints(self, output_dir, epoch, checkpoint_freq):
        """Write the rolling checkpoint and, when due, a numbered snapshot.

        Args:
            output_dir: Directory (``pathlib.Path``) receiving the files.
            epoch: Zero-based index of the epoch that just finished.
            checkpoint_freq: Save ``checkpoint{epoch:04}.pth`` every
                ``checkpoint_freq`` epochs. A falsy value (0/None) disables
                numbered snapshots instead of raising.
        """
        checkpoint_paths = [output_dir / "checkpoint.pth"]
        if checkpoint_freq and (epoch + 1) % checkpoint_freq == 0:
            checkpoint_paths.append(output_dir / f"checkpoint{epoch:04}.pth")

        # Build the state once and reuse it for every target path.
        state = self.state_dict(epoch)
        for checkpoint_path in checkpoint_paths:
            dist_utils.save_on_master(state, checkpoint_path)
|
|
|