π [Merge] branch 'MODELv2' into INFERENCE
Browse files- yolo/lazy.py +1 -1
- yolo/model/yolo.py +6 -4
yolo/lazy.py
CHANGED
|
@@ -26,7 +26,7 @@ def main(cfg: Config):
|
|
| 26 |
model = FastModelLoader(cfg).load_model()
|
| 27 |
device = torch.device(cfg.device)
|
| 28 |
else:
|
| 29 |
-
model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight
|
| 30 |
|
| 31 |
vec2box = Vec2Box(model, cfg.image_size, device)
|
| 32 |
|
|
|
|
| 26 |
model = FastModelLoader(cfg).load_model()
|
| 27 |
device = torch.device(cfg.device)
|
| 28 |
else:
|
| 29 |
+
model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight, device=device)
|
| 30 |
|
| 31 |
vec2box = Vec2Box(model, cfg.image_size, device)
|
| 32 |
|
yolo/model/yolo.py
CHANGED
|
@@ -2,9 +2,9 @@ import os
|
|
| 2 |
from typing import Any, Dict, List, Union
|
| 3 |
|
| 4 |
import torch
|
| 5 |
-
import torch.nn as nn
|
| 6 |
from loguru import logger
|
| 7 |
from omegaconf import ListConfig, OmegaConf
|
|
|
|
| 8 |
|
| 9 |
from yolo.config.config import Config, ModelConfig, YOLOLayer
|
| 10 |
from yolo.tools.dataset_preparation import prepare_weight
|
|
@@ -117,7 +117,9 @@ class YOLO(nn.Module):
|
|
| 117 |
raise ValueError(f"Unsupported layer type: {layer_type}")
|
| 118 |
|
| 119 |
|
| 120 |
-
def create_model(
|
|
|
|
|
|
|
| 121 |
"""Constructs and returns a model from a Dictionary configuration file.
|
| 122 |
|
| 123 |
Args:
|
|
@@ -133,9 +135,9 @@ def create_model(model_cfg: ModelConfig, class_num: int = 80, weight_path: str =
|
|
| 133 |
if not os.path.exists(weight_path):
|
| 134 |
logger.info(f"π Weight {weight_path} not found, try downloading")
|
| 135 |
prepare_weight(weight_path=weight_path)
|
| 136 |
-
model.model.load_state_dict(torch.load(weight_path))
|
| 137 |
logger.info("β
Success load model weight")
|
| 138 |
|
| 139 |
log_model_structure(model.model)
|
| 140 |
draw_model(model=model)
|
| 141 |
-
return model
|
|
|
|
| 2 |
from typing import Any, Dict, List, Union
|
| 3 |
|
| 4 |
import torch
|
|
|
|
| 5 |
from loguru import logger
|
| 6 |
from omegaconf import ListConfig, OmegaConf
|
| 7 |
+
from torch import device, nn
|
| 8 |
|
| 9 |
from yolo.config.config import Config, ModelConfig, YOLOLayer
|
| 10 |
from yolo.tools.dataset_preparation import prepare_weight
|
|
|
|
| 117 |
raise ValueError(f"Unsupported layer type: {layer_type}")
|
| 118 |
|
| 119 |
|
| 120 |
+
def create_model(
|
| 121 |
+
model_cfg: ModelConfig, class_num: int = 80, weight_path: str = "weights/v9-c.pt", device: device = device("cuda")
|
| 122 |
+
) -> YOLO:
|
| 123 |
"""Constructs and returns a model from a Dictionary configuration file.
|
| 124 |
|
| 125 |
Args:
|
|
|
|
| 135 |
if not os.path.exists(weight_path):
|
| 136 |
logger.info(f"π Weight {weight_path} not found, try downloading")
|
| 137 |
prepare_weight(weight_path=weight_path)
|
| 138 |
+
model.model.load_state_dict(torch.load(weight_path, map_location=device))
|
| 139 |
logger.info("β
Success load model weight")
|
| 140 |
|
| 141 |
log_model_structure(model.model)
|
| 142 |
draw_model(model=model)
|
| 143 |
+
return model.to(device)
|