| import sys | |
| from pathlib import Path | |
| import hydra | |
| from lightning import Trainer | |
| project_root = Path(__file__).resolve().parent.parent | |
| sys.path.append(str(project_root)) | |
| from yolo.config.config import Config | |
| from yolo.tools.solver import TrainModel, ValidateModel | |
| from yolo.utils.logging_utils import setup | |
| def main(cfg: Config): | |
| callbacks, loggers = setup(cfg) | |
| trainer = Trainer( | |
| accelerator="cuda", | |
| max_epochs=getattr(cfg.task, "epoch", None), | |
| precision="16-mixed", | |
| callbacks=callbacks, | |
| logger=loggers, | |
| log_every_n_steps=1, | |
| gradient_clip_val=10, | |
| ) | |
| match cfg.task.task: | |
| case "train": | |
| model = TrainModel(cfg) | |
| trainer.fit(model) | |
| case "validation": | |
| model = ValidateModel(cfg) | |
| trainer.validate(model) | |
| if __name__ == "__main__": | |
| main() | |