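"""Prediction engine for image mask inference.

Wraps model loading, image preprocessing, and mask prediction behind a
single interface, backed by either ONNX Runtime or PyTorch.
"""
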
import numpy as np
import onnxruntime as ort
import torch

from utils.preprocess import preprocess_image


class PredictionEngine:
    def __init__(self, model_path=None, use_onnx=True, input_size=256):
        """
        Initialize the prediction engine

        Args:
            model_path: Path to the model file (PyTorch or ONNX)
            use_onnx: Whether to use ONNX runtime for inference
            input_size: Input size for the model (default is 256)
        """
        self.use_onnx = use_onnx
        self.input_size = input_size

        if model_path:
            if use_onnx:
                self.model = self._load_onnx_model(model_path)
            else:
                self.model = self._load_pytorch_model(model_path)
        else:
            self.model = None

    def _load_onnx_model(self, model_path):
        """
        Load an ONNX model

        Args:
            model_path: Path to the ONNX model

        Returns:
            ONNX Runtime InferenceSession
        """
        # Prefer CUDA and fall back to CPU. ONNX Runtime can also fall back
        # silently within the provider list, so report the providers that
        # are actually in use.
        try:
            session = ort.InferenceSession(
                model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
            )
            print(f"ONNX model loaded (providers: {session.get_providers()})")
            return session
        except Exception as e:
            print(f"Could not load ONNX model with CUDA, falling back to CPU: {e}")
            session = ort.InferenceSession(
                model_path, providers=["CPUExecutionProvider"]
            )
            print("ONNX model loaded with CPU support")
            return session

    def _load_pytorch_model(self, model_path):
        """
        Load a PyTorch model

        Args:
            model_path: Path to the PyTorch model

        Returns:
            PyTorch model
        """
        from utils.model import load_model

        model = load_model(model_path)
        model.eval()  # ensure eval mode so dropout/batchnorm behave for inference
        return model

    def preprocess(self, image):
        """
        Preprocess an image for prediction

        Args:
            image: Input image (numpy array)

        Returns:
            Processed image suitable for the model
        """
        # Record the original spatial size so the predicted mask can be
        # resized back to the source resolution by the caller
        self.original_shape = image.shape[:2]

        # Preprocess image
        if self.use_onnx:
            # ONNX Runtime expects a NumPy array, so convert the tensor
            tensor = preprocess_image(image, img_size=self.input_size)
            return tensor.numpy()
        else:
            # PyTorch consumes the preprocessed tensor directly
            return preprocess_image(image, img_size=self.input_size)

    def predict(self, image):
        """
        Make a prediction on an image

        Args:
            image: Input image (numpy array)

        Returns:
            Predicted mask
        """
        if self.model is None:
            raise ValueError("Model not loaded. Initialize with a valid model path.")

        # Preprocess the image
        processed_input = self.preprocess(image)

        # Run inference
        if self.use_onnx:
            # Get input and output names
            input_name = self.model.get_inputs()[0].name
            output_name = self.model.get_outputs()[0].name

            # Run ONNX inference
            outputs = self.model.run([output_name], {input_name: processed_input})

            # Apply sigmoid to turn the raw logits into probabilities
            mask = 1 / (1 + np.exp(-outputs[0].squeeze()))
        else:
            # PyTorch inference
            with torch.no_grad():
                # Move to device
                device = next(self.model.parameters()).device
                processed_input = processed_input.to(device)

                # Forward pass
                output = self.model(processed_input)
                output = torch.sigmoid(output)

                # Convert to numpy
                mask = output.cpu().numpy().squeeze()

        return mask


def load_pytorch_model(model_path):
    """
    Load the PyTorch model for prediction

    Args:
        model_path: Path to the PyTorch model

    Returns:
        PredictionEngine instance
    """
    return PredictionEngine(model_path, use_onnx=False)


def load_onnx_model(model_path, input_size=256):
    """
    Load the ONNX model for prediction

    Args:
        model_path: Path to the ONNX model
        input_size: Input size for the model

    Returns:
        PredictionEngine instance
    """
    return PredictionEngine(model_path, use_onnx=True, input_size=input_size)
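

if __name__ == "__main__":
    # Minimal usage sketch. "models/model.onnx" is a placeholder path, not a
    # file guaranteed to exist in this repo; point it at your own export. A
    # random image stands in for real input so the sketch has no asset
    # dependencies.
    engine = load_onnx_model("models/model.onnx", input_size=256)

    dummy = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
    mask = engine.predict(dummy)
    print(f"Mask shape: {mask.shape}, value range: [{mask.min():.3f}, {mask.max():.3f}]")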