import sys
from pathlib import Path

import pytest
import torch
from hydra import compose, initialize

# Put the project root (three directories above this file) on sys.path so the
# local `yolo` package below can be imported.
project_root = Path(__file__).resolve().parent.parent.parent
sys.path.append(str(project_root))

from yolo.utils.loss import YOLOLoss  # noqa: E402  (requires the sys.path tweak above)
@pytest.fixture
def cfg():
    """Compose the project's Hydra configuration named ``config``.

    The config directory ``../../yolo/config`` is passed to Hydra's
    ``initialize``, which interprets ``config_path`` relative to the calling
    file (this test module).

    Returns:
        The config object produced by ``hydra.compose``.
    """
    # compose() needs an initialized Hydra, which initialize() provides for the
    # duration of the ``with`` block, so compose inside it.
    with initialize(config_path="../../yolo/config", version_base=None):
        return compose(config_name="config")
@pytest.fixture
def loss_function(cfg) -> YOLOLoss:
    """Build the ``YOLOLoss`` instance under test.

    Args:
        cfg: Hydra config supplied by the ``cfg`` fixture.

    Returns:
        A ``YOLOLoss`` constructed from ``cfg``.
    """
    return YOLOLoss(cfg)
@pytest.fixture
def data():
    """Return an all-zero ``(predicts, targets)`` pair for the loss test.

    All tensors are allocated on CUDA when it is available, otherwise on CPU.

    Returns:
        A tuple ``(predicts, targets)`` where:

        - ``predicts`` is a list of 2 lists. Each inner list holds three zero
          tensors of shape ``(1, 144, 80, 80)``, ``(1, 144, 40, 40)`` and
          ``(1, 144, 20, 20)``. NOTE(review): presumably one inner list per
          prediction head and one tensor per feature-map scale; the meaning of
          the 144 channels is defined by ``YOLOLoss`` and should be confirmed there.
        - ``targets`` is a ``(20, 6)`` zero tensor. Presumably one row per
          ground-truth box; TODO confirm the column layout against ``YOLOLoss``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    targets = torch.zeros(20, 6, device=device)
    # Spatial sizes 80, 40 and 20 (80 integer-divided by 1, 2 and 4).
    predicts = [
        [torch.zeros(1, 144, 80 // i, 80 // i, device=device) for i in (1, 2, 4)]
        for _ in range(2)
    ]
    return predicts, targets
def test_yolo_loss(loss_function, data):
    """Run ``YOLOLoss`` on the all-zero ``data`` fixture and check its outputs.

    The loss call returns three values. Judging by the names, these are the
    box IoU, DFL and classification terms. For these degenerate all-zero
    inputs the test expects both box terms to be NaN and the classification
    term to be infinite.

    NOTE(review): this pins degenerate behaviour of ``YOLOLoss`` rather than a
    meaningful loss value; confirm that this is intended.
    """
    predicts, targets = data
    iou_term, dfl_term, cls_term = loss_function(predicts, targets)

    # Assumes each term is a single-element tensor; otherwise the truth test
    # on torch.isnan / torch.isinf would raise.
    assert torch.isnan(iou_term)
    assert torch.isnan(dfl_term)
    assert torch.isinf(cls_term)