Spaces:
Sleeping
Sleeping
import torch | |
import numpy as np | |
import cv2 | |
import onnxruntime as ort | |
from utils.preprocess import preprocess_image | |
from utils.model import load_model | |
class PredictionEngine:
    """Predict a mask from an image using either an ONNX Runtime session or a PyTorch model.

    The engine preprocesses the image with ``preprocess_image``, runs the loaded
    model, and applies a sigmoid to the raw model output.
    """

    def __init__(self, model_path=None, use_onnx=True, input_size=256):
        """
        Initialize the prediction engine

        Args:
            model_path: Path to the model file (PyTorch or ONNX). When falsy, no
                model is loaded and ``self.model`` stays ``None`` until assigned.
            use_onnx: Whether to use ONNX runtime for inference
            input_size: Size passed to ``preprocess_image`` as ``img_size``
                (default is 256)
        """
        self.use_onnx = use_onnx
        self.input_size = input_size
        # First two dimensions of the most recent image given to preprocess();
        # initialised here so the attribute always exists.
        self.original_shape = None
        if model_path:
            if use_onnx:
                self.model = self._load_onnx_model(model_path)
            else:
                self.model = load_model(model_path)
        else:
            self.model = None

    def _load_onnx_model(self, model_path):
        """
        Load an ONNX model, preferring CUDA when this ONNX Runtime build offers it.

        Args:
            model_path: Path to the ONNX model
        Returns:
            ONNX Runtime InferenceSession
        Raises:
            Whatever ``ort.InferenceSession`` raises if the model cannot be
            loaded even on the CPU provider.
        """
        # Only request providers this onnxruntime build actually has. Requesting an
        # unavailable provider does not raise; ONNX Runtime just warns and uses CPU,
        # so the old unconditional "CUDA support" message could be wrong.
        preferred = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        available = set(ort.get_available_providers())
        providers = [p for p in preferred if p in available] or ["CPUExecutionProvider"]
        try:
            session = ort.InferenceSession(model_path, providers=providers)
        except Exception as e:
            if providers == ["CPUExecutionProvider"]:
                # Nothing left to fall back to.
                raise
            print(f"Could not load ONNX model with CUDA, falling back to CPU: {e}")
            session = ort.InferenceSession(
                model_path, providers=["CPUExecutionProvider"]
            )
        # Report what the session is really using, not what was requested.
        if "CUDAExecutionProvider" in session.get_providers():
            print("ONNX model loaded with CUDA support")
        else:
            print("ONNX model loaded with CPU support")
        return session

    def preprocess(self, image):
        """
        Preprocess an image for prediction

        Args:
            image: Input image (numpy array)
        Returns:
            A NumPy array when ``use_onnx`` is set, otherwise the tensor returned
            by ``preprocess_image``.
        """
        # Keep the original image's leading two dimensions for reference
        self.original_shape = image.shape[:2]
        tensor = preprocess_image(image, img_size=self.input_size)
        if self.use_onnx:
            # ONNX Runtime consumes NumPy arrays; detach()/cpu() make the
            # conversion safe even if the tensor tracks gradients or is on a GPU.
            return tensor.detach().cpu().numpy()
        # For PyTorch
        return tensor

    def predict(self, image):
        """
        Make a prediction on an image

        Args:
            image: Input image (numpy array)
        Returns:
            Predicted mask: the sigmoid of the model output, squeezed, as a
            NumPy array
        Raises:
            ValueError: if no model has been loaded
        """
        if self.model is None:
            raise ValueError("Model not loaded. Initialize with a valid model path.")
        processed_input = self.preprocess(image)

        if self.use_onnx:
            # Feed the first graph input and fetch the first graph output.
            input_name = self.model.get_inputs()[0].name
            output_name = self.model.get_outputs()[0].name
            outputs = self.model.run([output_name], {input_name: processed_input})
            mask = self._sigmoid(outputs[0].squeeze())
        else:
            with torch.no_grad():
                # Run on whichever device holds the weights; a model without
                # parameters falls back to CPU instead of raising StopIteration.
                first_param = next(self.model.parameters(), None)
                device = (
                    first_param.device if first_param is not None else torch.device("cpu")
                )
                output = torch.sigmoid(self.model(processed_input.to(device)))
                mask = output.cpu().numpy().squeeze()
        return mask

    @staticmethod
    def _sigmoid(x):
        """Numerically stable logistic sigmoid, element-wise.

        Equivalent to ``1 / (1 + exp(-x))`` but only ever exponentiates
        non-positive values, so large-magnitude logits never overflow.
        The input dtype (e.g. float32) is preserved.
        """
        x = np.asarray(x)
        e = np.exp(-np.abs(x))
        return np.where(x >= 0, 1 / (1 + e), e / (1 + e))
def load_pytorch_model(model_path, input_size=256):
    """
    Load the PyTorch model for prediction

    Args:
        model_path: Path to the PyTorch model
        input_size: Input size for the model (default is 256). Accepted here
            for parity with load_onnx_model, so both backends can be configured
            the same way.
    Returns:
        PredictionEngine instance using the PyTorch backend
    """
    return PredictionEngine(model_path, use_onnx=False, input_size=input_size)
def load_onnx_model(model_path, input_size=256):
    """Build a PredictionEngine backed by ONNX Runtime.

    Args:
        model_path: Location of the ONNX model file to load.
        input_size: Image size handed to the preprocessing step (256 by default).

    Returns:
        A PredictionEngine configured for ONNX inference.
    """
    engine = PredictionEngine(
        model_path=model_path,
        use_onnx=True,
        input_size=input_size,
    )
    return engine
# For backwards compatibility
def predict(model, image):
    """
    Legacy function for prediction

    Args:
        model: A PredictionEngine, a PyTorch ``torch.nn.Module``, or any other
            object, which is treated as an ONNX Runtime session (it must provide
            ``get_inputs``/``get_outputs``/``run``).
        image: Input image (numpy array)
    Returns:
        Predicted mask
    """
    if isinstance(model, PredictionEngine):
        return model.predict(image)
    # Wrap a bare model in an engine with the matching backend. Previously every
    # bare model was treated as ONNX, so passing an nn.Module failed with
    # AttributeError on get_inputs().
    engine = PredictionEngine(use_onnx=not isinstance(model, torch.nn.Module))
    engine.model = model
    return engine.predict(image)