🎨 [Update] progress and tqdm
Browse files- yolo/tools/log_helper.py +3 -3
- yolo/utils/dataloader.py +2 -1
yolo/tools/log_helper.py
CHANGED
|
@@ -47,7 +47,7 @@ class CustomProgress:
|
|
| 47 |
self.wandb = wandb.init(project="YOLO", resume="allow", mode="online", dir="runs", name=cfg.name)
|
| 48 |
|
| 49 |
def start_train(self, num_epochs: int):
|
| 50 |
-
self.task_epoch = self.progress.add_task("[cyan]Epochs", total=num_epochs)
|
| 51 |
|
| 52 |
def one_epoch(self):
|
| 53 |
self.progress.update(self.task_epoch, advance=1)
|
|
@@ -63,9 +63,9 @@ class CustomProgress:
|
|
| 63 |
for loss_name, loss_value in loss_dict.items():
|
| 64 |
self.wandb.log({f"Loss/{loss_name}": loss_value})
|
| 65 |
|
| 66 |
-
loss_str = "
|
| 67 |
for loss_name, loss_val in loss_dict.items():
|
| 68 |
-
loss_str += f" {
|
| 69 |
|
| 70 |
self.progress.update(self.batch_task, advance=1, description=f"[green]Batches [white]{loss_str}")
|
| 71 |
|
|
|
|
| 47 |
self.wandb = wandb.init(project="YOLO", resume="allow", mode="online", dir="runs", name=cfg.name)
|
| 48 |
|
| 49 |
def start_train(self, num_epochs: int):
|
| 50 |
+
self.task_epoch = self.progress.add_task("[cyan]Epochs [white]| Loss | Box | DFL | BCE |", total=num_epochs)
|
| 51 |
|
| 52 |
def one_epoch(self):
|
| 53 |
self.progress.update(self.task_epoch, advance=1)
|
|
|
|
| 63 |
for loss_name, loss_value in loss_dict.items():
|
| 64 |
self.wandb.log({f"Loss/{loss_name}": loss_value})
|
| 65 |
|
| 66 |
+
loss_str = "| -.-- |"
|
| 67 |
for loss_name, loss_val in loss_dict.items():
|
| 68 |
+
loss_str += f" {loss_val:2.2f} |"
|
| 69 |
|
| 70 |
self.progress.update(self.batch_task, advance=1, description=f"[green]Batches [white]{loss_str}")
|
| 71 |
|
yolo/utils/dataloader.py
CHANGED
|
@@ -7,6 +7,7 @@ import numpy as np
|
|
| 7 |
import torch
|
| 8 |
from loguru import logger
|
| 9 |
from PIL import Image
|
|
|
|
| 10 |
from torch.utils.data import DataLoader, Dataset
|
| 11 |
from torchvision.transforms import functional as TF
|
| 12 |
from tqdm.rich import tqdm
|
|
@@ -74,7 +75,7 @@ class YoloDataset(Dataset):
|
|
| 74 |
|
| 75 |
data = []
|
| 76 |
valid_inputs = 0
|
| 77 |
-
for image_name in
|
| 78 |
if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
|
| 79 |
continue
|
| 80 |
image_id, _ = path.splitext(image_name)
|
|
|
|
| 7 |
import torch
|
| 8 |
from loguru import logger
|
| 9 |
from PIL import Image
|
| 10 |
+
from rich.progress import track
|
| 11 |
from torch.utils.data import DataLoader, Dataset
|
| 12 |
from torchvision.transforms import functional as TF
|
| 13 |
from tqdm.rich import tqdm
|
|
|
|
| 75 |
|
| 76 |
data = []
|
| 77 |
valid_inputs = 0
|
| 78 |
+
for image_name in track(images_list, description="Filtering data"):
|
| 79 |
if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
|
| 80 |
continue
|
| 81 |
image_id, _ = path.splitext(image_name)
|