import torch
import numpy as np
import onnxruntime as ort

from utils.preprocess import preprocess_image
from utils.model import load_model


class PredictionEngine:
    def __init__(self, model_path=None, use_onnx=True, input_size=256):
        """
        Initialize the prediction engine.

        Args:
            model_path: Path to the model file (PyTorch or ONNX)
            use_onnx: Whether to use ONNX Runtime for inference
            input_size: Input size for the model (default is 256)
        """
        self.use_onnx = use_onnx
        self.input_size = input_size
        if model_path:
            if use_onnx:
                self.model = self._load_onnx_model(model_path)
            else:
                self.model = load_model(model_path)
        else:
            self.model = None

    def _load_onnx_model(self, model_path):
        """
        Load an ONNX model.

        Args:
            model_path: Path to the ONNX model

        Returns:
            ONNX Runtime InferenceSession
        """
        # Try CUDA first, fall back to CPU if needed
        try:
            session = ort.InferenceSession(
                model_path,
                providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
            )
            print("ONNX model loaded with CUDA support")
            return session
        except Exception as e:
            print(f"Could not load ONNX model with CUDA, falling back to CPU: {e}")
            session = ort.InferenceSession(
                model_path,
                providers=["CPUExecutionProvider"],
            )
            print("ONNX model loaded with CPU support")
            return session

    def preprocess(self, image):
        """
        Preprocess an image for prediction.

        Args:
            image: Input image (numpy array)

        Returns:
            Processed input suitable for the model
        """
        # Keep the original spatial dimensions for reference
        self.original_shape = image.shape[:2]

        tensor = preprocess_image(image, img_size=self.input_size)
        if self.use_onnx:
            # ONNX Runtime expects a numpy array, not a torch tensor
            return tensor.numpy()
        # PyTorch path keeps the tensor as-is
        return tensor

    def predict(self, image):
        """
        Make a prediction on an image.

        Args:
            image: Input image (numpy array)

        Returns:
            Predicted mask
        """
        if self.model is None:
            raise ValueError("Model not loaded. Initialize with a valid model path.")

        # Preprocess the image
        processed_input = self.preprocess(image)

        if self.use_onnx:
            # Get the session's input and output names
            input_name = self.model.get_inputs()[0].name
            output_name = self.model.get_outputs()[0].name

            # Run ONNX inference
            outputs = self.model.run([output_name], {input_name: processed_input})

            # Apply sigmoid to the raw logits
            mask = 1 / (1 + np.exp(-outputs[0].squeeze()))
        else:
            # PyTorch inference
            with torch.no_grad():
                # Move the input to the model's device
                device = next(self.model.parameters()).device
                processed_input = processed_input.to(device)

                # Forward pass and sigmoid activation
                output = torch.sigmoid(self.model(processed_input))

                # Convert to numpy
                mask = output.cpu().numpy().squeeze()

        return mask


def load_pytorch_model(model_path):
    """
    Load a PyTorch model for prediction.

    Args:
        model_path: Path to the PyTorch model

    Returns:
        PredictionEngine instance
    """
    return PredictionEngine(model_path, use_onnx=False)


def load_onnx_model(model_path, input_size=256):
    """
    Load an ONNX model for prediction.

    Args:
        model_path: Path to the ONNX model
        input_size: Input size for the model

    Returns:
        PredictionEngine instance
    """
    return PredictionEngine(model_path, use_onnx=True, input_size=input_size)


# For backwards compatibility
def predict(model, image):
    """
    Legacy function for prediction.

    Args:
        model: Model instance (PredictionEngine, ONNX InferenceSession,
            or PyTorch module)
        image: Input image

    Returns:
        Predicted mask
    """
    if isinstance(model, PredictionEngine):
        return model.predict(image)

    # Wrap a bare model in an engine, choosing the backend from the model's
    # type rather than always assuming ONNX, so PyTorch models also work here.
    engine = PredictionEngine(use_onnx=isinstance(model, ort.InferenceSession))
    engine.model = model
    return engine.predict(image)
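

# Usage sketch: a minimal end-to-end example of the engine, assuming a
# segmentation model that outputs per-pixel logits. The paths
# "models/segmenter.onnx" and "example.jpg" are hypothetical placeholders,
# not files shipped with this repo; substitute your own model and image.
if __name__ == "__main__":
    import sys

    import cv2

    # Read the test image (BGR numpy array, as the engine expects)
    image_path = sys.argv[1] if len(sys.argv) > 1 else "example.jpg"
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read input image: {image_path}")

    # ONNX path; _load_onnx_model falls back to CPU if CUDA is unavailable
    engine = load_onnx_model("models/segmenter.onnx", input_size=256)
    mask = engine.predict(image)

    # The mask holds sigmoid probabilities in [0, 1]; threshold at 0.5
    # to obtain a binary segmentation and save it as an 8-bit image.
    binary_mask = (mask > 0.5).astype(np.uint8) * 255
    cv2.imwrite("mask.png", binary_mask)
    print(f"Saved mask.png with shape {binary_mask.shape}")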