add codebase
- app.py +143 -0
- predict.py +165 -0
- requirements.txt +13 -0
- utils/helpers.py +73 -0
- utils/model.py +143 -0
- utils/onnx_converter.py +99 -0
- utils/preprocess.py +36 -0
app.py
ADDED
@@ -0,0 +1,143 @@
import streamlit as st
import cv2
import numpy as np
import os
from PIL import Image
import torch
from predict import load_onnx_model
from utils.helpers import calculate_deforestation_metrics, create_overlay

# Workaround: keep Streamlit's file watcher from introspecting torch.classes
torch.classes.__path__ = []

# Set page config
st.set_page_config(page_title="Deforestation Detection", page_icon="🌳", layout="wide")

# Set constants
MODEL_INPUT_SIZE = 256  # The size our model expects

# Load ONNX model
@st.cache_resource
def load_cached_onnx_model():
    model_path = "models/deforestation_model.onnx"
    return load_onnx_model(model_path, input_size=MODEL_INPUT_SIZE)

def process_image(model, image):
    """Process a single image and return results"""
    # Save original image dimensions for display
    orig_height, orig_width = image.shape[:2]

    # Make prediction
    mask = model.predict(image)

    # Resize mask back to original dimensions for display
    display_mask = cv2.resize(mask, (orig_width, orig_height))

    # Create binary mask for visualization
    binary_mask = (display_mask > 0.5).astype(np.uint8) * 255

    # Create colored overlay
    overlay = create_overlay(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), display_mask)

    # Calculate metrics
    metrics = calculate_deforestation_metrics(mask)

    return binary_mask, overlay, metrics

def main():
    # App title and description
    st.title("🌳 Deforestation Detection")
    st.markdown(
        """
        This app detects areas of deforestation in satellite or aerial images of forests.
        Upload an image to get started!
        """
    )

    # Model info
    st.info(
        f"⚙️ Model optimized for {MODEL_INPUT_SIZE}x{MODEL_INPUT_SIZE} pixel images using ONNX runtime"
    )

    # Load model
    try:
        model = load_cached_onnx_model()
    except Exception as e:
        st.error(f"Error loading model: {e}")
        st.info(
            "Make sure you have converted your PyTorch model to ONNX format using the utils/onnx_converter.py script."
        )
        st.code(
            "python -m utils.onnx_converter models/best_model_100.pth models/deforestation_model.onnx"
        )
        return

    # File uploader for images
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        # Load image
        file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
        image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)

        # Display original image
        st.subheader("Original Image")
        st.image(
            cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
            caption="Uploaded Image",
            use_container_width=True,
        )

        # Add a spinner while processing
        with st.spinner("Processing..."):
            try:
                # Process image
                binary_mask, overlay, metrics = process_image(model, image)

                # Display results in columns
                col1, col2 = st.columns(2)

                with col1:
                    st.subheader("Segmentation Result")
                    st.image(
                        binary_mask,
                        caption="Forest Areas (White)",
                        use_container_width=True,
                    )

                with col2:
                    st.subheader("Overlay Visualization")
                    st.image(
                        overlay,
                        caption="Green: Forest, Brown: Deforested",
                        use_container_width=True,
                    )

                # Display metrics
                st.subheader("Deforestation Analysis")

                # Create metrics cards
                metrics_col1, metrics_col2, metrics_col3 = st.columns(3)

                with metrics_col1:
                    st.metric(
                        label="Forest Coverage",
                        value=f"{metrics['forest_percentage']:.1f}%",
                    )

                with metrics_col2:
                    st.metric(
                        label="Deforested Area",
                        value=f"{metrics['deforested_percentage']:.1f}%",
                    )

                with metrics_col3:
                    st.metric(
                        label="Deforestation Level",
                        value=metrics["deforestation_level"],
                    )

            except Exception as e:
                st.error(f"Error during processing: {e}")

if __name__ == "__main__":
    main()
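The Streamlit UI above is a thin wrapper around predict.py and utils/helpers.py, so the same pipeline can be exercised headlessly. The following is a minimal smoke-test sketch, assuming the converted model sits at models/deforestation_model.onnx; sample.jpg and overlay.png are placeholder file names, not part of the repo.

import cv2

from predict import load_onnx_model
from utils.helpers import calculate_deforestation_metrics, create_overlay

model = load_onnx_model("models/deforestation_model.onnx", input_size=256)

image = cv2.imread("sample.jpg")  # BGR image of any resolution (placeholder path)
mask = model.predict(image)       # float mask at the model's 256x256 resolution
metrics = calculate_deforestation_metrics(mask)

# create_overlay resizes the mask internally and expects an RGB image
overlay = create_overlay(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), mask)

print(metrics["forest_percentage"], metrics["deforestation_level"])
cv2.imwrite("overlay.png", cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))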
predict.py
ADDED
@@ -0,0 +1,165 @@
import torch
import numpy as np
import cv2
import onnxruntime as ort
from utils.preprocess import preprocess_image
from utils.model import load_model


class PredictionEngine:
    def __init__(self, model_path=None, use_onnx=True, input_size=256):
        """
        Initialize the prediction engine

        Args:
            model_path: Path to the model file (PyTorch or ONNX)
            use_onnx: Whether to use ONNX runtime for inference
            input_size: Input size for the model (default is 256)
        """
        self.use_onnx = use_onnx
        self.input_size = input_size

        if model_path:
            if use_onnx:
                self.model = self._load_onnx_model(model_path)
            else:
                self.model = load_model(model_path)
        else:
            self.model = None

    def _load_onnx_model(self, model_path):
        """
        Load an ONNX model

        Args:
            model_path: Path to the ONNX model

        Returns:
            ONNX Runtime InferenceSession
        """
        # Try with CUDA first, fall back to CPU if needed
        try:
            session = ort.InferenceSession(
                model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
            )
            print("ONNX model loaded with CUDA support")
            return session
        except Exception as e:
            print(f"Could not load ONNX model with CUDA, falling back to CPU: {e}")
            session = ort.InferenceSession(
                model_path, providers=["CPUExecutionProvider"]
            )
            print("ONNX model loaded with CPU support")
            return session

    def preprocess(self, image):
        """
        Preprocess an image for prediction

        Args:
            image: Input image (numpy array)

        Returns:
            Processed image suitable for the model
        """
        # Keep the original image for reference
        self.original_shape = image.shape[:2]

        # Preprocess image
        if self.use_onnx:
            # For ONNX, we need to ensure the input is exactly the expected size
            tensor = preprocess_image(image, img_size=self.input_size)
            return tensor.numpy()
        else:
            # For PyTorch
            return preprocess_image(image, img_size=self.input_size)

    def predict(self, image):
        """
        Make a prediction on an image

        Args:
            image: Input image (numpy array)

        Returns:
            Predicted mask
        """
        if self.model is None:
            raise ValueError("Model not loaded. Initialize with a valid model path.")

        # Preprocess the image
        processed_input = self.preprocess(image)

        # Run inference
        if self.use_onnx:
            # Get input and output names
            input_name = self.model.get_inputs()[0].name
            output_name = self.model.get_outputs()[0].name

            # Run ONNX inference
            outputs = self.model.run([output_name], {input_name: processed_input})

            # Apply sigmoid to output
            mask = 1 / (1 + np.exp(-outputs[0].squeeze()))
        else:
            # PyTorch inference
            with torch.no_grad():
                # Move to device
                device = next(self.model.parameters()).device
                processed_input = processed_input.to(device)

                # Forward pass
                output = self.model(processed_input)
                output = torch.sigmoid(output)

                # Convert to numpy
                mask = output.cpu().numpy().squeeze()

        return mask


def load_pytorch_model(model_path):
    """
    Load the PyTorch model for prediction

    Args:
        model_path: Path to the PyTorch model

    Returns:
        PredictionEngine instance
    """
    return PredictionEngine(model_path, use_onnx=False)


def load_onnx_model(model_path, input_size=256):
    """
    Load the ONNX model for prediction

    Args:
        model_path: Path to the ONNX model
        input_size: Input size for the model

    Returns:
        PredictionEngine instance
    """
    return PredictionEngine(model_path, use_onnx=True, input_size=input_size)


# For backwards compatibility
def predict(model, image):
    """
    Legacy function for prediction

    Args:
        model: Model instance
        image: Input image

    Returns:
        Predicted mask
    """
    if isinstance(model, PredictionEngine):
        return model.predict(image)

    engine = PredictionEngine(use_onnx=True)
    engine.model = model
    return engine.predict(image)
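Both backends expose the same predict() interface and both end with a sigmoid, so their masks should agree closely. A parity-check sketch, assuming the checkpoint referenced in app.py (models/best_model_100.pth) and the exported ONNX model are both present; sample.jpg is a placeholder image path:

import numpy as np
import cv2

from predict import load_onnx_model, load_pytorch_model

image = cv2.imread("sample.jpg")  # placeholder input image

onnx_engine = load_onnx_model("models/deforestation_model.onnx")
torch_engine = load_pytorch_model("models/best_model_100.pth")

onnx_mask = onnx_engine.predict(image)
torch_mask = torch_engine.predict(image)

# Small numerical differences between ONNX Runtime and PyTorch are expected.
print("max abs diff:", np.abs(onnx_mask - torch_mask).max())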
requirements.txt
ADDED
@@ -0,0 +1,13 @@
streamlit
torch
torchvision
opencv-python
albumentations
numpy
Pillow
scikit-image
scikit-learn
matplotlib
onnxruntime
onnxruntime-gpu
onnx
utils/helpers.py
ADDED
@@ -0,0 +1,73 @@
import cv2
import numpy as np


def calculate_deforestation_metrics(mask, threshold=0.5):
    """
    Calculate deforestation metrics from the predicted mask

    Args:
        mask: Predicted mask (numpy array)
        threshold: Threshold to binarize the mask (default is 0.5)

    Returns:
        Dictionary containing deforestation metrics
    """
    # Binarize the mask
    binary_mask = (mask > threshold).astype(np.uint8)

    # Calculate pixel counts
    total_pixels = binary_mask.size
    forest_pixels = np.sum(binary_mask)
    deforested_pixels = total_pixels - forest_pixels

    # Calculate percentages
    forest_percentage = (forest_pixels / total_pixels) * 100
    deforested_percentage = (deforested_pixels / total_pixels) * 100

    # Determine deforestation level
    if deforested_percentage < 20:
        level = "Low"
    elif deforested_percentage < 50:
        level = "Medium"
    else:
        level = "High"

    return {
        "forest_pixels": forest_pixels,
        "deforested_pixels": deforested_pixels,
        "forest_percentage": forest_percentage,
        "deforested_percentage": deforested_percentage,
        "deforestation_level": level,
    }


def create_overlay(original_image, mask, threshold=0.5, alpha=0.5):
    """
    Create a visualization by overlaying the mask on the original image

    Args:
        original_image: Original RGB image
        mask: Predicted mask
        threshold: Threshold to binarize the mask
        alpha: Opacity of the overlay

    Returns:
        Overlay image
    """
    # Resize mask to match original image if needed
    if original_image.shape[:2] != mask.shape[:2]:
        mask = cv2.resize(mask, (original_image.shape[1], original_image.shape[0]))

    # Create binary mask
    binary_mask = (mask > threshold).astype(np.uint8) * 255

    # Create a colored mask (green for forest, brown for deforested)
    colored_mask = np.zeros_like(original_image)
    colored_mask[binary_mask == 255] = [0, 255, 0]  # Green for forest
    colored_mask[binary_mask == 0] = [150, 75, 0]   # Brown for deforested

    # Create overlay
    overlay = cv2.addWeighted(original_image, 1 - alpha, colored_mask, alpha, 0)

    return overlay
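A quick illustration of the metric thresholds on a synthetic mask (any float array in [0, 1] works; values above 0.5 count as forest):

import numpy as np
from utils.helpers import calculate_deforestation_metrics

mask = np.zeros((10, 10), dtype=np.float32)
mask[:7, :] = 0.9  # 70 of 100 pixels above the 0.5 threshold

metrics = calculate_deforestation_metrics(mask)
print(metrics["forest_percentage"])      # 70.0
print(metrics["deforested_percentage"])  # 30.0
print(metrics["deforestation_level"])    # "Medium" (20% <= deforested < 50%)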
utils/model.py
ADDED
@@ -0,0 +1,143 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class AttentionGate(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionGate, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )

        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)

        return x * psi


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class AttentionUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super(AttentionUNet, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.attention_gates = nn.ModuleList()

        # Down part of U-Net
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of U-Net
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(
                    feature * 2,
                    feature,
                    kernel_size=2,
                    stride=2,
                )
            )
            # Attention Gate
            self.attention_gates.append(
                AttentionGate(F_g=feature, F_l=feature, F_int=feature // 2)
            )

            self.ups.append(DoubleConv(feature * 2, feature))

        # Bottleneck
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)

        # Final Conv
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        # Encoder path
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]  # Reverse to use from back

        # Decoder path
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx // 2]

            # If sizes don't match
            if x.shape != skip_connection.shape:
                x = F.interpolate(x, size=skip_connection.shape[2:])

            # Apply attention gate
            skip_connection = self.attention_gates[idx // 2](g=x, x=skip_connection)

            # Concatenate
            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        # Final conv
        return self.final_conv(x)


def load_model(model_path):
    """
    Load the trained model

    Args:
        model_path: Path to the model weights

    Returns:
        Loaded model
    """
    # Define device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize model
    model = AttentionUNet(in_channels=3, out_channels=1)

    # Load model weights
    model.load_state_dict(torch.load(model_path, map_location=device))

    # Set model to evaluation mode
    model.eval()

    return model.to(device)
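A shape sanity check for the architecture; this sketch uses randomly initialized weights, so no checkpoint is needed:

import torch
from utils.model import AttentionUNet

model = AttentionUNet(in_channels=3, out_channels=1)
model.eval()

x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    logits = model(x)

# Raw logits; the sigmoid is applied downstream in predict.py.
print(logits.shape)  # torch.Size([1, 1, 256, 256])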
utils/onnx_converter.py
ADDED
@@ -0,0 +1,99 @@
import torch
import sys
import numpy as np
from utils.model import AttentionUNet
import onnx
import onnxruntime as ort


def convert_to_onnx(pytorch_model_path, onnx_output_path, input_size=256):
    """
    Convert a PyTorch model to ONNX format

    Args:
        pytorch_model_path: Path to the PyTorch model
        onnx_output_path: Path to save the ONNX model
        input_size: Input size for the model (default is 256x256)
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device used for conversion: {device}")

    model = AttentionUNet(in_channels=3, out_channels=1)
    model.to(device)
    model.load_state_dict(torch.load(pytorch_model_path, map_location=device))

    model.eval()

    # Create dummy input
    dummy_input = torch.randn(1, 3, input_size, input_size, device=device)

    # Export the model
    torch.onnx.export(
        model,                     # model being run
        dummy_input,               # model input (or a tuple for multiple inputs)
        onnx_output_path,          # where to save the model
        export_params=True,        # store the trained parameter weights inside the model file
        opset_version=12,          # the ONNX version to export the model to
        do_constant_folding=True,  # whether to execute constant folding for optimization
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size", 2: "height", 3: "width"},  # variable length axes
            "output": {0: "batch_size", 2: "height", 3: "width"},
        },
    )

    print(f"Model converted and saved to {onnx_output_path}")

    verify_onnx_model(onnx_output_path, input_size)


def verify_onnx_model(onnx_model_path, input_size=256):
    """
    Verify the ONNX model to ensure it was exported correctly

    Args:
        onnx_model_path: Path to the ONNX model
        input_size: Input size used during export
    """
    try:
        onnx_model = onnx.load(onnx_model_path)
        onnx.checker.check_model(onnx_model)
        print("ONNX model is valid")
    except Exception as e:
        print(f"ONNX model validation failed: {e}")
        return False

    try:
        session = ort.InferenceSession(
            onnx_model_path, providers=["CPUExecutionProvider"]
        )

        input_data = np.random.rand(1, 3, input_size, input_size).astype(np.float32)

        # Get input and output names
        input_name = session.get_inputs()[0].name
        output_name = session.get_outputs()[0].name

        # Run inference
        outputs = session.run([output_name], {input_name: input_data})

        print(f"ONNX model inference test passed. Output shape: {outputs[0].shape}")
        return True
    except Exception as e:
        print(f"ONNX model inference test failed: {e}")
        return False


if __name__ == "__main__":
    if len(sys.argv) < 3:
        print(
            "Usage: python -m utils.onnx_converter <pytorch_model_path> <onnx_output_path> [input_size]"
        )
        sys.exit(1)

    pytorch_model_path = sys.argv[1]
    onnx_output_path = sys.argv[2]
    input_size = int(sys.argv[3]) if len(sys.argv) > 3 else 256

    convert_to_onnx(pytorch_model_path, onnx_output_path, input_size)
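Besides the CLI shown in the __main__ block, the converter can be called programmatically; a sketch using the checkpoint path referenced in app.py:

from utils.onnx_converter import convert_to_onnx

convert_to_onnx(
    pytorch_model_path="models/best_model_100.pth",
    onnx_output_path="models/deforestation_model.onnx",
    input_size=256,
)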
utils/preprocess.py
ADDED
@@ -0,0 +1,36 @@
import cv2
import torch
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2


def preprocess_image(image, img_size=256):
    """
    Preprocess the input image for model prediction

    Args:
        image: Input image (numpy array)
        img_size: Size to resize image to (model was trained on 256x256)

    Returns:
        Preprocessed image tensor
    """
    # Define the transformation
    transform = A.Compose(
        [
            A.Resize(height=img_size, width=img_size),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ]
    )

    # Apply the transformation
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    augmented = transform(image=image)
    image_tensor = augmented["image"]

    # Add batch dimension
    image_tensor = image_tensor.unsqueeze(0)

    return image_tensor
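The preprocessing contract in one sketch (synthetic BGR input; the output is a batched, ImageNet-normalized float tensor):

import numpy as np
from utils.preprocess import preprocess_image

bgr_image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
tensor = preprocess_image(bgr_image, img_size=256)

print(tensor.shape)  # torch.Size([1, 3, 256, 256])
print(tensor.dtype)  # torch.float32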