# SemanticSegmentationModel/semantic-segmentation/SemanticModel/.ipynb_checkpoints/model_core-checkpoint.py
import inspect

import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import utils

from SemanticModel.encoder_management import initialize_encoder
from SemanticModel.custom_losses import FocalLossFunction, TverskyLossFunction, EnhancedCrossEntropy
from SemanticModel.image_preprocessing import get_preprocessing_pipeline
class SegmentationModel:
    """Bundle an smp segmentation network with its loss, output activation
    and input preprocessing.

    Construction resolves the class list, picks a loss/activation pair and
    then either unpickles a saved model (``weights`` ending in ``'pth'``) or
    builds a new ``architecture``/``encoder`` network with
    segmentation_models_pytorch.

    Args:
        classes: Class names. With two or fewer names, any name equal to
            ``'background'`` (case-insensitive) is excluded from the predicted
            classes and remembered via ``background_flag``. With more than
            two names every entry, background included, is predicted.
        architecture: Architecture key (``'unet'``, ``'unet++'``,
            ``'deeplabv3'``, ``'deeplabv3+'``, ``'fpn'``, ``'linknet'``,
            ``'manet'``, ``'pan'``, ``'pspnet'``; matched case-insensitively).
            Not used when loading a saved ``.pth`` model.
        encoder: Encoder name forwarded to smp.
        weights: Encoder weight name forwarded to smp (e.g. ``'imagenet'``),
            ``None`` (routed to the new-model path), or a path ending in
            ``'pth'`` that holds a whole pickled model.
        loss: ``'dice'``, ``'bce_with_logits'``, ``'focal'`` or ``'tversky'``
            (case-insensitive). When falsy, defaults to ``'bce_with_logits'``
            for several classes and ``'dice'`` for a single class; unknown
            names fall back to ``'dice'`` with a printed warning.

    Raises:
        ValueError: if no non-background class remains to predict, or the
            architecture is not supported.
    """

    def __init__(self, classes=('background', 'foreground'), architecture='unet',
                 encoder='timm-regnety_120', weights='imagenet', loss=None):
        # Immutable tuple default: a list default would be one shared object
        # across every call.
        self._initialize_classes(classes)
        self.architecture = architecture
        self.encoder = encoder
        self.weights = weights
        self._setup_loss_function(loss)
        self._initialize_model()

    def _initialize_classes(self, classes):
        """Set ``classes``, ``class_values``, ``background_flag`` and
        ``n_classes`` from the user-supplied class names.

        ``class_values`` holds each kept class's index in the *input* list
        (presumably the mask pixel value for that class; TODO confirm against
        the dataset code).

        Raises:
            ValueError: if no class is left to predict.
        """
        names = list(classes)  # copy so later caller mutations don't leak in
        if len(names) <= 2:
            # Binary / single-class setup: background is implicit, so only the
            # non-background class(es) get an output channel.
            kept = [(idx, name) for idx, name in enumerate(names)
                    if name.lower() != 'background']
            self.classes = [name for _, name in kept]
            self.class_values = [idx for idx, _ in kept]
            # Use the same case-insensitive test as the filter above; a
            # case-sensitive check would drop 'Background' from the classes
            # yet report no background, losing it from config_data().
            self.background_flag = len(kept) < len(names)
        else:
            # Multi-class setup: every class, background included, is predicted.
            self.classes = names
            self.class_values = list(range(len(names)))
            self.background_flag = False
        self.n_classes = len(self.classes)
        if self.n_classes == 0:
            # A 0-channel head cannot produce a usable segmentation.
            raise ValueError(
                f'At least one non-background class is required, got {names!r}'
            )

    def _setup_loss_function(self, loss):
        """Select ``self.activation``, ``self.loss`` and ``self.loss_name``.

        Only the chosen loss object is instantiated. ``'dice'`` pairs with a
        softmax (multi-class) or sigmoid (single-class) model activation; the
        other losses use activation ``None`` so the model output is passed to
        the loss unactivated (presumably as logits).
        """
        multiclass = self.n_classes > 1
        # name -> (model activation, zero-argument factory for the loss object)
        loss_factories = {
            'bce_with_logits': (
                None,
                EnhancedCrossEntropy if multiclass else utils.losses.BCEWithLogitsLoss,
            ),
            'dice': ('softmax' if multiclass else 'sigmoid', utils.losses.DiceLoss),
            'focal': (None, FocalLossFunction),
            'tversky': (None, TverskyLossFunction),
        }
        if not loss:
            loss = 'bce_with_logits' if multiclass else 'dice'
        name = loss.lower()
        if name not in loss_factories:
            print(f'Invalid loss: {loss}, defaulting to dice')
            name = 'dice'
        activation, make_loss = loss_factories[name]
        self.activation = activation
        self.loss = make_loss()
        # Store the canonical (lower-case) key actually used.
        self.loss_name = name

    def _initialize_model(self):
        """Load a saved model when ``weights`` is a 'pth' path, else build a new one."""
        # ``weights`` may be None (no pretrained encoder weights); guard the
        # string check so that value does not raise AttributeError.
        if isinstance(self.weights, str) and self.weights.endswith('pth'):
            self._load_pretrained_model()
        else:
            self._create_new_model()

    def _load_pretrained_model(self):
        """Unpickle the whole model stored at ``self.weights``.

        Unwraps ``DataParallel``/``DistributedDataParallel`` containers, then
        tries (best effort) to attach imagenet preprocessing for
        ``self.encoder``; on failure ``self.preprocessing`` is ``None``.

        SECURITY: ``torch.load`` on a full model unpickles arbitrary Python
        objects -- only load checkpoint files from trusted sources.
        """
        print('Loading pretrained model...')
        load_kwargs = {}
        # torch >= 2.6 defaults to weights_only=True, which refuses to
        # unpickle a whole nn.Module; older torch (< 1.13) has no such keyword.
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False
        model = torch.load(self.weights, **load_kwargs)
        wrappers = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
        if isinstance(model, wrappers):
            model = model.module
        self.model = model
        try:
            preprocessing_fn = smp.encoders.get_preprocessing_fn(self.encoder, 'imagenet')
            self.preprocessing = get_preprocessing_pipeline(preprocessing_fn)
        except Exception:
            # Deliberately best-effort, but no longer swallows
            # KeyboardInterrupt / SystemExit like the previous bare except.
            print('Failed to configure preprocessing. Setting to None.')
            self.preprocessing = None

    def _create_new_model(self):
        """Build a fresh smp model for ``self.architecture`` and ``self.encoder``.

        Raises:
            ValueError: if ``self.architecture`` is not a supported key.
        """
        architectures = {
            'unet': smp.Unet,
            'unet++': smp.UnetPlusPlus,
            'deeplabv3': smp.DeepLabV3,
            'deeplabv3+': smp.DeepLabV3Plus,
            'fpn': smp.FPN,
            'linknet': smp.Linknet,
            'manet': smp.MAnet,
            'pan': smp.PAN,
            'pspnet': smp.PSPNet,
        }
        # Validate first so a typo fails before any preprocessing/encoder setup.
        model_cls = architectures.get(str(self.architecture).lower())
        if model_cls is None:
            raise ValueError(f'Unsupported architecture: {self.architecture}')
        preprocessing_fn = smp.encoders.get_preprocessing_fn(self.encoder, 'imagenet')
        self.preprocessing = get_preprocessing_pipeline(preprocessing_fn)
        initialize_encoder(name=self.encoder, weights=self.weights)
        self.model = model_cls(
            encoder_name=self.encoder,
            encoder_weights=self.weights,
            classes=self.n_classes,
            activation=self.activation,
        )

    def config_data(self):
        """Return a dict describing this model's configuration.

        ``'classes'`` is a fresh list; when a background class was filtered
        out it is re-inserted as ``'background'`` at the front.
        NOTE(review): background is always placed first, even if it was not
        first in the original ``classes`` argument.
        """
        classes = list(self.classes)
        if self.background_flag:
            classes.insert(0, 'background')
        return {
            'architecture': self.architecture,
            'encoder': self.encoder,
            'weights': self.weights,
            'activation': self.activation,
            'loss': self.loss_name,
            'classes': classes,
        }
def list_architectures():
    """Return the architecture keys that ``SegmentationModel`` can build.

    The keys match the architecture table used when creating a new model.
    A new list is built on every call, so callers may mutate it freely.
    """
    supported = 'unet unet++ deeplabv3 deeplabv3+ fpn linknet manet pan pspnet'
    return supported.split()