smokxy committed
Commit 8c38d83 · 1 Parent(s): 03ad78b

add codebase
Files changed (7)
  1. app.py +143 -0
  2. predict.py +165 -0
  3. requirements.txt +13 -0
  4. utils/helpers.py +73 -0
  5. utils/model.py +143 -0
  6. utils/onnx_converter.py +99 -0
  7. utils/preprocess.py +36 -0
app.py ADDED
@@ -0,0 +1,143 @@
+ import streamlit as st
+ import cv2
+ import numpy as np
+ import torch
+ from predict import load_onnx_model
+ from utils.helpers import calculate_deforestation_metrics, create_overlay
+
+ # Workaround: keep Streamlit's file watcher from introspecting torch.classes
+ torch.classes.__path__ = []
+
+ # Set page config
+ st.set_page_config(page_title="Deforestation Detection", page_icon="🌳", layout="wide")
+
+ # Set constants
+ MODEL_INPUT_SIZE = 256  # The size the model expects
+
+ # Load ONNX model (cached across Streamlit reruns)
+ @st.cache_resource
+ def load_cached_onnx_model():
+     model_path = "models/deforestation_model.onnx"
+     return load_onnx_model(model_path, input_size=MODEL_INPUT_SIZE)
+
+ def process_image(model, image):
+     """Process a single BGR image and return mask, overlay, and metrics"""
+     # Save original image dimensions for display
+     orig_height, orig_width = image.shape[:2]
+
+     # Make prediction
+     mask = model.predict(image)
+
+     # Resize mask back to original dimensions for display
+     display_mask = cv2.resize(mask, (orig_width, orig_height))
+
+     # Create binary mask for visualization
+     binary_mask = (display_mask > 0.5).astype(np.uint8) * 255
+
+     # Create colored overlay
+     overlay = create_overlay(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), display_mask)
+
+     # Calculate metrics
+     metrics = calculate_deforestation_metrics(mask)
+
+     return binary_mask, overlay, metrics
+
+ def main():
+     # App title and description
+     st.title("🌳 Deforestation Detection")
+     st.markdown(
+         """
+         This app detects areas of deforestation in satellite or aerial images of forests.
+         Upload an image to get started!
+         """
+     )
+
+     # Model info
+     st.info(
+         f"⚙️ Model optimized for {MODEL_INPUT_SIZE}x{MODEL_INPUT_SIZE} pixel images using ONNX Runtime"
+     )
+
+     # Load model
+     try:
+         model = load_cached_onnx_model()
+     except Exception as e:
+         st.error(f"Error loading model: {e}")
+         st.info(
+             "Make sure you have converted your PyTorch model to ONNX format using the utils/onnx_converter.py script."
+         )
+         st.code(
+             "python -m utils.onnx_converter models/best_model_100.pth models/deforestation_model.onnx"
+         )
+         return
+
+     # File uploader for images
+     uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
+
+     if uploaded_file is not None:
+         # Decode the uploaded bytes into a BGR image
+         file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
+         image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
+
+         # Display original image
+         st.subheader("Original Image")
+         st.image(
+             cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
+             caption="Uploaded Image",
+             use_container_width=True,
+         )
+
+         # Add a spinner while processing
+         with st.spinner("Processing..."):
+             try:
+                 # Process image
+                 binary_mask, overlay, metrics = process_image(model, image)
+
+                 # Display results in columns
+                 col1, col2 = st.columns(2)
+
+                 with col1:
+                     st.subheader("Segmentation Result")
+                     st.image(
+                         binary_mask,
+                         caption="Forest Areas (White)",
+                         use_container_width=True,
+                     )
+
+                 with col2:
+                     st.subheader("Overlay Visualization")
+                     st.image(
+                         overlay,
+                         caption="Green: Forest, Brown: Deforested",
+                         use_container_width=True,
+                     )
+
+                 # Display metrics
+                 st.subheader("Deforestation Analysis")
+
+                 # Create metrics cards
+                 metrics_col1, metrics_col2, metrics_col3 = st.columns(3)
+
+                 with metrics_col1:
+                     st.metric(
+                         label="Forest Coverage",
+                         value=f"{metrics['forest_percentage']:.1f}%",
+                     )
+
+                 with metrics_col2:
+                     st.metric(
+                         label="Deforested Area",
+                         value=f"{metrics['deforested_percentage']:.1f}%",
+                     )
+
+                 with metrics_col3:
+                     st.metric(
+                         label="Deforestation Level",
+                         value=metrics["deforestation_level"],
+                     )
+
+             except Exception as e:
+                 st.error(f"Error during processing: {e}")
+
+ if __name__ == "__main__":
+     main()
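
To try the app locally, install the dependencies, convert a trained checkpoint to ONNX (the converter command is the one app.py suggests; the checkpoint filename is an example), and launch Streamlit:

    pip install -r requirements.txt
    python -m utils.onnx_converter models/best_model_100.pth models/deforestation_model.onnx
    streamlit run app.py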
predict.py ADDED
@@ -0,0 +1,165 @@
+ import torch
+ import numpy as np
+ import onnxruntime as ort
+ from utils.preprocess import preprocess_image
+ from utils.model import load_model
+
+
+ class PredictionEngine:
+     def __init__(self, model_path=None, use_onnx=True, input_size=256):
+         """
+         Initialize the prediction engine
+
+         Args:
+             model_path: Path to the model file (PyTorch or ONNX)
+             use_onnx: Whether to use ONNX Runtime for inference
+             input_size: Input size for the model (default is 256)
+         """
+         self.use_onnx = use_onnx
+         self.input_size = input_size
+
+         if model_path:
+             if use_onnx:
+                 self.model = self._load_onnx_model(model_path)
+             else:
+                 self.model = load_model(model_path)
+         else:
+             self.model = None
+
+     def _load_onnx_model(self, model_path):
+         """
+         Load an ONNX model
+
+         Args:
+             model_path: Path to the ONNX model
+
+         Returns:
+             ONNX Runtime InferenceSession
+         """
+         # Try CUDA first, fall back to CPU if needed
+         try:
+             session = ort.InferenceSession(
+                 model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
+             )
+             print("ONNX model loaded with CUDA support")
+             return session
+         except Exception as e:
+             print(f"Could not load ONNX model with CUDA, falling back to CPU: {e}")
+             session = ort.InferenceSession(
+                 model_path, providers=["CPUExecutionProvider"]
+             )
+             print("ONNX model loaded with CPU support")
+             return session
+
+     def preprocess(self, image):
+         """
+         Preprocess an image for prediction
+
+         Args:
+             image: Input image (numpy array)
+
+         Returns:
+             Processed input suitable for the model
+         """
+         # Keep the original image shape for reference
+         self.original_shape = image.shape[:2]
+
+         # Preprocess image
+         if self.use_onnx:
+             # ONNX Runtime takes a numpy array of exactly the expected size
+             tensor = preprocess_image(image, img_size=self.input_size)
+             return tensor.numpy()
+         else:
+             # PyTorch takes the tensor directly
+             return preprocess_image(image, img_size=self.input_size)
+
+     def predict(self, image):
+         """
+         Make a prediction on an image
+
+         Args:
+             image: Input image (numpy array)
+
+         Returns:
+             Predicted mask
+         """
+         if self.model is None:
+             raise ValueError("Model not loaded. Initialize with a valid model path.")
+
+         # Preprocess the image
+         processed_input = self.preprocess(image)
+
+         # Run inference
+         if self.use_onnx:
+             # Get input and output names
+             input_name = self.model.get_inputs()[0].name
+             output_name = self.model.get_outputs()[0].name
+
+             # Run ONNX inference
+             outputs = self.model.run([output_name], {input_name: processed_input})
+
+             # Apply sigmoid to the raw logits
+             mask = 1 / (1 + np.exp(-outputs[0].squeeze()))
+         else:
+             # PyTorch inference
+             with torch.no_grad():
+                 # Move input to the model's device
+                 device = next(self.model.parameters()).device
+                 processed_input = processed_input.to(device)
+
+                 # Forward pass
+                 output = self.model(processed_input)
+                 output = torch.sigmoid(output)
+
+                 # Convert to numpy
+                 mask = output.cpu().numpy().squeeze()
+
+         return mask
+
+
+ def load_pytorch_model(model_path):
+     """
+     Load the PyTorch model for prediction
+
+     Args:
+         model_path: Path to the PyTorch model
+
+     Returns:
+         PredictionEngine instance
+     """
+     return PredictionEngine(model_path, use_onnx=False)
+
+
+ def load_onnx_model(model_path, input_size=256):
+     """
+     Load the ONNX model for prediction
+
+     Args:
+         model_path: Path to the ONNX model
+         input_size: Input size for the model
+
+     Returns:
+         PredictionEngine instance
+     """
+     return PredictionEngine(model_path, use_onnx=True, input_size=input_size)
+
+
+ # For backwards compatibility
+ def predict(model, image):
+     """
+     Legacy function for prediction
+
+     Args:
+         model: Model instance
+         image: Input image
+
+     Returns:
+         Predicted mask
+     """
+     if isinstance(model, PredictionEngine):
+         return model.predict(image)
+
+     engine = PredictionEngine(use_onnx=True)
+     engine.model = model
+     return engine.predict(image)
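
A minimal usage sketch for the engine above, assuming a converted model at the default path used by app.py; sample.jpg is an illustrative filename:

    import cv2
    from predict import load_onnx_model

    engine = load_onnx_model("models/deforestation_model.onnx", input_size=256)
    image = cv2.imread("sample.jpg")  # BGR, as OpenCV loads it
    mask = engine.predict(image)      # float mask in [0, 1] at 256x256
    print(mask.shape, mask.min(), mask.max())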
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ streamlit
+ torch
+ torchvision
+ opencv-python
+ albumentations
+ numpy
+ Pillow
+ scikit-image
+ scikit-learn
+ matplotlib
+ # Note: onnxruntime and onnxruntime-gpu install the same Python module and can
+ # conflict when present together; keep onnxruntime-gpu only on CUDA machines.
+ onnxruntime
+ onnxruntime-gpu
+ onnx
utils/helpers.py ADDED
@@ -0,0 +1,73 @@
+ import cv2
+ import numpy as np
+
+
+ def calculate_deforestation_metrics(mask, threshold=0.5):
+     """
+     Calculate deforestation metrics from the predicted mask
+
+     Args:
+         mask: Predicted mask (numpy array)
+         threshold: Threshold to binarize the mask (default is 0.5)
+
+     Returns:
+         Dictionary containing deforestation metrics
+     """
+     # Binarize the mask
+     binary_mask = (mask > threshold).astype(np.uint8)
+
+     # Calculate pixel counts
+     total_pixels = binary_mask.size
+     forest_pixels = np.sum(binary_mask)
+     deforested_pixels = total_pixels - forest_pixels
+
+     # Calculate percentages
+     forest_percentage = (forest_pixels / total_pixels) * 100
+     deforested_percentage = (deforested_pixels / total_pixels) * 100
+
+     # Determine deforestation level
+     if deforested_percentage < 20:
+         level = "Low"
+     elif deforested_percentage < 50:
+         level = "Medium"
+     else:
+         level = "High"
+
+     return {
+         "forest_pixels": forest_pixels,
+         "deforested_pixels": deforested_pixels,
+         "forest_percentage": forest_percentage,
+         "deforested_percentage": deforested_percentage,
+         "deforestation_level": level,
+     }
+
+
+ def create_overlay(original_image, mask, threshold=0.5, alpha=0.5):
+     """
+     Create a visualization by overlaying the mask on the original image
+
+     Args:
+         original_image: Original RGB image
+         mask: Predicted mask
+         threshold: Threshold to binarize the mask
+         alpha: Opacity of the overlay
+
+     Returns:
+         Overlay image
+     """
+     # Resize mask to match the original image if needed
+     if original_image.shape[:2] != mask.shape[:2]:
+         mask = cv2.resize(mask, (original_image.shape[1], original_image.shape[0]))
+
+     # Create binary mask
+     binary_mask = (mask > threshold).astype(np.uint8) * 255
+
+     # Create a colored mask (green for forest, brown for deforested)
+     colored_mask = np.zeros_like(original_image)
+     colored_mask[binary_mask == 255] = [0, 255, 0]  # Green for forest
+     colored_mask[binary_mask == 0] = [150, 75, 0]   # Brown for deforested
+
+     # Blend the colored mask with the original image
+     overlay = cv2.addWeighted(original_image, 1 - alpha, colored_mask, alpha, 0)
+
+     return overlay
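
A quick sanity check of the metrics helper on a synthetic half-forest mask (the expected values follow from the thresholds above):

    import numpy as np
    from utils.helpers import calculate_deforestation_metrics

    mask = np.zeros((256, 256), dtype=np.float32)
    mask[:, :128] = 1.0  # left half counts as forest
    m = calculate_deforestation_metrics(mask)
    print(m["forest_percentage"])    # 50.0
    print(m["deforestation_level"])  # "High", since 50% deforested is not < 50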
utils/model.py ADDED
@@ -0,0 +1,143 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class AttentionGate(nn.Module):
+     def __init__(self, F_g, F_l, F_int):
+         super(AttentionGate, self).__init__()
+         # Gating signal transform
+         self.W_g = nn.Sequential(
+             nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
+             nn.BatchNorm2d(F_int),
+         )
+
+         # Skip-connection transform
+         self.W_x = nn.Sequential(
+             nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
+             nn.BatchNorm2d(F_int),
+         )
+
+         # Attention coefficients in [0, 1]
+         self.psi = nn.Sequential(
+             nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
+             nn.BatchNorm2d(1),
+             nn.Sigmoid(),
+         )
+
+         self.relu = nn.ReLU(inplace=True)
+
+     def forward(self, g, x):
+         g1 = self.W_g(g)
+         x1 = self.W_x(x)
+         psi = self.relu(g1 + x1)
+         psi = self.psi(psi)
+
+         # Scale the skip connection by the attention coefficients
+         return x * psi
+
+
+ class DoubleConv(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(DoubleConv, self).__init__()
+         self.conv = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(inplace=True),
+         )
+
+     def forward(self, x):
+         return self.conv(x)
+
+
+ class AttentionUNet(nn.Module):
+     def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
+         super(AttentionUNet, self).__init__()
+         self.ups = nn.ModuleList()
+         self.downs = nn.ModuleList()
+         self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+         self.attention_gates = nn.ModuleList()
+
+         # Down part of U-Net
+         for feature in features:
+             self.downs.append(DoubleConv(in_channels, feature))
+             in_channels = feature
+
+         # Up part of U-Net: each step is a ConvTranspose2d followed by a DoubleConv
+         for feature in reversed(features):
+             self.ups.append(
+                 nn.ConvTranspose2d(
+                     feature * 2,
+                     feature,
+                     kernel_size=2,
+                     stride=2,
+                 )
+             )
+             # Attention gate for this resolution
+             self.attention_gates.append(
+                 AttentionGate(F_g=feature, F_l=feature, F_int=feature // 2)
+             )
+
+             self.ups.append(DoubleConv(feature * 2, feature))
+
+         # Bottleneck
+         self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
+
+         # Final 1x1 conv producing the logit map
+         self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
+
+     def forward(self, x):
+         skip_connections = []
+
+         # Encoder path
+         for down in self.downs:
+             x = down(x)
+             skip_connections.append(x)
+             x = self.pool(x)
+
+         x = self.bottleneck(x)
+         skip_connections = skip_connections[::-1]  # Reverse so the deepest skip comes first
+
+         # Decoder path: self.ups alternates upsampling and DoubleConv blocks
+         for idx in range(0, len(self.ups), 2):
+             x = self.ups[idx](x)
+             skip_connection = skip_connections[idx // 2]
+
+             # If sizes don't match (e.g., odd input dimensions), resize before concatenating
+             if x.shape != skip_connection.shape:
+                 x = F.interpolate(x, size=skip_connection.shape[2:])
+
+             # Apply attention gate to the skip connection
+             skip_connection = self.attention_gates[idx // 2](g=x, x=skip_connection)
+
+             # Concatenate along the channel dimension
+             concat_skip = torch.cat((skip_connection, x), dim=1)
+             x = self.ups[idx + 1](concat_skip)
+
+         # Final conv
+         return self.final_conv(x)
+
+
+ def load_model(model_path):
+     """
+     Load the trained model
+
+     Args:
+         model_path: Path to the model weights
+
+     Returns:
+         Loaded model
+     """
+     # Define device
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     # Initialize model
+     model = AttentionUNet(in_channels=3, out_channels=1)
+
+     # Load model weights
+     model.load_state_dict(torch.load(model_path, map_location=device))
+
+     # Set model to evaluation mode
+     model.eval()
+
+     return model.to(device)
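
A shape check for the network, assuming a 256x256 RGB input; CPU suffices, and the output is raw logits (sigmoid is applied at inference time in predict.py):

    import torch
    from utils.model import AttentionUNet

    model = AttentionUNet(in_channels=3, out_channels=1).eval()
    with torch.no_grad():
        y = model(torch.randn(1, 3, 256, 256))
    print(y.shape)  # torch.Size([1, 1, 256, 256])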
utils/onnx_converter.py ADDED
@@ -0,0 +1,99 @@
+ import torch
+ import sys
+ import numpy as np
+ from utils.model import AttentionUNet
+ import onnx
+ import onnxruntime as ort
+
+
+ def convert_to_onnx(pytorch_model_path, onnx_output_path, input_size=256):
+     """
+     Convert a PyTorch model to ONNX format
+
+     Args:
+         pytorch_model_path: Path to the PyTorch model
+         onnx_output_path: Path to save the ONNX model
+         input_size: Input size for the model (default is 256x256)
+     """
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"Device used for conversion: {device}")
+
+     model = AttentionUNet(in_channels=3, out_channels=1)
+     model.to(device)
+     model.load_state_dict(torch.load(pytorch_model_path, map_location=device))
+
+     model.eval()
+
+     # Create dummy input
+     dummy_input = torch.randn(1, 3, input_size, input_size, device=device)
+
+     # Export the model
+     torch.onnx.export(
+         model,                     # model being run
+         dummy_input,               # model input (or a tuple for multiple inputs)
+         onnx_output_path,          # where to save the model
+         export_params=True,        # store the trained parameter weights inside the model file
+         opset_version=12,          # the ONNX opset version to export to
+         do_constant_folding=True,  # execute constant folding for optimization
+         input_names=["input"],
+         output_names=["output"],
+         dynamic_axes={
+             "input": {0: "batch_size", 2: "height", 3: "width"},  # variable-length axes
+             "output": {0: "batch_size", 2: "height", 3: "width"},
+         },
+     )
+
+     print(f"Model converted and saved to {onnx_output_path}")
+
+     verify_onnx_model(onnx_output_path, input_size)
+
+
+ def verify_onnx_model(onnx_model_path, input_size=256):
+     """
+     Verify the ONNX model to ensure it was exported correctly
+
+     Args:
+         onnx_model_path: Path to the ONNX model
+         input_size: Input size used during export
+     """
+     # Structural check of the exported graph
+     try:
+         onnx_model = onnx.load(onnx_model_path)
+         onnx.checker.check_model(onnx_model)
+         print("ONNX model is valid")
+     except Exception as e:
+         print(f"ONNX model validation failed: {e}")
+         return False
+
+     # Smoke-test inference on random input
+     try:
+         session = ort.InferenceSession(
+             onnx_model_path, providers=["CPUExecutionProvider"]
+         )
+
+         input_data = np.random.rand(1, 3, input_size, input_size).astype(np.float32)
+
+         # Get input and output names
+         input_name = session.get_inputs()[0].name
+         output_name = session.get_outputs()[0].name
+
+         # Run inference
+         outputs = session.run([output_name], {input_name: input_data})
+
+         print(f"ONNX model inference test passed. Output shape: {outputs[0].shape}")
+         return True
+     except Exception as e:
+         print(f"ONNX model inference test failed: {e}")
+         return False
+
+
+ if __name__ == "__main__":
+     if len(sys.argv) < 3:
+         print(
+             "Usage: python -m utils.onnx_converter <pytorch_model_path> <onnx_output_path> [input_size]"
+         )
+         sys.exit(1)
+
+     pytorch_model_path = sys.argv[1]
+     onnx_output_path = sys.argv[2]
+     input_size = int(sys.argv[3]) if len(sys.argv) > 3 else 256
+
+     convert_to_onnx(pytorch_model_path, onnx_output_path, input_size)
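
Because the export declares dynamic height and width axes, the saved graph can in principle be run at other resolutions; a small check, assuming a model exported as above (sides that are multiples of 16 survive the four 2x poolings cleanly):

    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(
        "models/deforestation_model.onnx", providers=["CPUExecutionProvider"]
    )
    x = np.random.rand(1, 3, 320, 320).astype(np.float32)
    out = session.run(None, {session.get_inputs()[0].name: x})[0]
    print(out.shape)  # expected (1, 1, 320, 320)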
utils/preprocess.py ADDED
@@ -0,0 +1,36 @@
+ import cv2
+ import albumentations as A
+ from albumentations.pytorch import ToTensorV2
+
+
+ def preprocess_image(image, img_size=256):
+     """
+     Preprocess the input image for model prediction
+
+     Args:
+         image: Input BGR image (numpy array, as loaded by OpenCV)
+         img_size: Size to resize the image to (the model was trained on 256x256)
+
+     Returns:
+         Preprocessed image tensor of shape (1, 3, img_size, img_size)
+     """
+     # Define the transformation: resize, ImageNet normalization, tensor conversion
+     transform = A.Compose(
+         [
+             A.Resize(height=img_size, width=img_size),
+             A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+             ToTensorV2(),
+         ]
+     )
+
+     # Convert BGR (OpenCV) to RGB, then apply the transformation
+     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+     augmented = transform(image=image)
+     image_tensor = augmented["image"]
+
+     # Add batch dimension
+     image_tensor = image_tensor.unsqueeze(0)
+
+     return image_tensor
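
A small check of the preprocessing output on a dummy BGR frame (shape chosen arbitrarily):

    import numpy as np
    from utils.preprocess import preprocess_image

    frame = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)  # fake BGR image
    tensor = preprocess_image(frame, img_size=256)
    print(tensor.shape)  # torch.Size([1, 3, 256, 256]), normalized with ImageNet stats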