|
import os |
|
import cv2 |
|
import time |
|
import torch |
|
import imageio |
|
import tifffile |
|
import numpy as np |
|
import slidingwindow |
|
import rasterio as rio |
|
import geopandas as gpd |
|
from shapely.geometry import Polygon |
|
from rasterio import mask as riomask |
|
from torch.utils.data import DataLoader |
|
from SemanticModel.visualization import generate_color_mapping |
|
from SemanticModel.image_preprocessing import get_validation_augmentations |
|
from SemanticModel.data_loader import InferenceDataset, StreamingDataset |
|
from SemanticModel.utilities import calc_image_size, convert_coordinates |
|
|
|
class PredictionPipeline:
    """Run a configured segmentation model on images, folders, rasters and video frames.

    The pipeline reads everything model-related from ``model_config``: the
    model itself, the preprocessing callable, the class list, ``n_classes``
    and ``background_flag``.  Predictions are produced as 2D label masks and
    optionally converted to colour images via ``generate_color_mapping``.
    """

    # RGB overlay colours used by ``predict_video_frames``: the first
    # foreground mask is tinted red, every further mask blue.
    _OVERLAY_PRIMARY = (255.0, 0.0, 0.0)
    _OVERLAY_SECONDARY = (0.0, 0.0, 255.0)

    def __init__(self, model_config, device=None):
        """Store the configuration, pick a device and put the model in eval mode.

        Args:
            model_config: Object exposing ``model``, ``classes``,
                ``background_flag``, ``n_classes`` and ``preprocessing``.
            device: Torch device to run on; defaults to CUDA when available,
                otherwise CPU.
        """
        self.config = model_config
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if model_config.background_flag:
            self.classes = ['background'] + model_config.classes
        else:
            self.classes = model_config.classes
        self.colors = generate_color_mapping(len(self.classes))
        self.model = model_config.model.to(self.device)
        self.model.eval()

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _preprocess_image(self, image_path, target_size=None):
        """Load one image from disk and prepare it for the model.

        Args:
            image_path (str): Path of the image to read with OpenCV.
            target_size (int): Size passed to ``calc_image_size``; defaults to
                the longer side of the image.

        Returns:
            tuple: ``(preprocessed_image, (height, width))`` where the second
            element is the size of the image as read from disk.

        Raises:
            FileNotFoundError: If OpenCV cannot read ``image_path``.
        """
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        height, width = image.shape[:2]

        target_size = target_size or max(height, width)
        test_height, test_width = calc_image_size(image, target_size)

        augmentation = get_validation_augmentations(test_width, test_height)
        image = augmentation(image=image)['image']
        image = self.config.preprocessing(image=image)['image']

        return image, (height, width)

    def _predict_label_mask(self, image, output_hw=None):
        """Run the model on one preprocessed image and return a 2D label mask.

        Args:
            image (np.ndarray): Preprocessed image accepted by
                ``torch.from_numpy``; a batch axis is added here.
            output_hw (tuple): Optional ``(height, width)``; when the mask has
                a different size it is resized with nearest-neighbour
                interpolation.

        Returns:
            np.ndarray: For ``n_classes > 1`` a ``uint8`` arg-max mask (class
            axis is assumed to be axis 0 of the squeezed model output);
            otherwise the rounded single-channel output.
        """
        x_tensor = torch.from_numpy(image).to(self.device).unsqueeze(0)
        with torch.no_grad():
            prediction = self.model.predict(x_tensor)
        scores = prediction.squeeze().cpu().numpy()

        if self.config.n_classes > 1:
            # np.argmax yields int64, which cv2.resize rejects; uint8 matches
            # what _format_prediction and predict_raster use anyway.
            labels = np.argmax(scores, axis=0).astype(np.uint8)
        else:
            labels = scores.round()

        if output_hw is not None:
            out_h, out_w = int(output_hw[0]), int(output_hw[1])
            if tuple(labels.shape[:2]) != (out_h, out_w):
                # cv2 expects (width, height).
                labels = cv2.resize(labels, (out_w, out_h),
                                    interpolation=cv2.INTER_NEAREST)
        return labels

    def _format_prediction(self, prediction, format):
        """Convert a label mask to the requested output representation.

        Args:
            prediction (np.ndarray): 2D label mask.
            format (str): ``'integer'`` for a ``uint8`` mask, ``'color'`` for
                an RGB image built by ``_apply_color_mapping``.

        Raises:
            ValueError: For any other ``format`` value.
        """
        if format == 'integer':
            return prediction.astype('uint8')
        if format == 'color':
            return self._apply_color_mapping(prediction)
        raise ValueError(f"Unsupported format: {format}")

    def _apply_color_mapping(self, prediction):
        """Paint every non-background class of a 2D label mask with its colour.

        Pixels of a class named ``background`` (case-insensitive) and pixels
        matching no class index stay black.
        """
        height, width = prediction.shape
        colored = np.zeros((height, width, 3), dtype='uint8')

        for class_index, class_name in enumerate(self.classes):
            if class_name.lower() == 'background':
                continue
            colored[prediction == class_index] = self.colors[class_index]

        return colored

    def _save_prediction(self, prediction, source_path, output_dir, format):
        """Write a prediction as ``<source stem>_pred.png`` into ``output_dir``.

        When ``output_dir`` is ``None`` the directory of ``source_path`` is
        used (previously this raised ``TypeError``).  The directory is created
        if it does not exist.  ``format`` is accepted for interface
        compatibility and is not used.
        """
        if output_dir is None:
            output_dir = os.path.dirname(source_path) or '.'
        os.makedirs(output_dir, exist_ok=True)

        stem = os.path.splitext(os.path.basename(source_path))[0]
        output_path = os.path.join(output_dir, f"{stem}_pred.png")
        # NOTE(review): cv2.imwrite interprets 3-channel data as BGR; confirm
        # the channel order produced by generate_color_mapping.
        cv2.imwrite(output_path, prediction)

    # ------------------------------------------------------------------
    # Single image / directory
    # ------------------------------------------------------------------

    def predict_single_image(self, image_path, target_size=None, output_dir=None,
                             format='integer', save_output=True):
        """Generate a prediction for a single image.

        Args:
            image_path (str): Image to segment.
            target_size (int): Optional inference size (see ``_preprocess_image``).
            output_dir (str): Where to save the PNG; defaults to the image's
                own directory when ``save_output`` is true.
            format (str): ``'integer'`` or ``'color'``.
            save_output (bool): Whether to write the prediction to disk.

        Returns:
            np.ndarray: The formatted prediction at the original image size.
        """
        image, original_dims = self._preprocess_image(image_path, target_size)
        prediction = self._predict_label_mask(image, original_dims)
        prediction = self._format_prediction(prediction, format)

        if save_output:
            self._save_prediction(prediction, image_path, output_dir, format)

        return prediction

    def predict_directory(self, input_dir, target_size=None, output_dir=None,
                          fixed_size=True, format='integer'):
        """Generate and save predictions for every image in a directory.

        Args:
            input_dir (str): Directory handed to ``InferenceDataset``.
            target_size (int): Optional square augmentation size.
            output_dir (str): Output directory; defaults to
                ``<input_dir>/predictions``.
            fixed_size (bool): Forwarded to ``get_validation_augmentations``.
            format (str): ``'integer'`` or ``'color'``.

        Returns:
            str: The directory the predictions were written to.
        """
        output_dir = output_dir or os.path.join(input_dir, 'predictions')
        os.makedirs(output_dir, exist_ok=True)

        augmentation = None
        if target_size:
            augmentation = get_validation_augmentations(
                target_size, target_size, fixed_size=fixed_size
            )
        dataset = InferenceDataset(
            input_dir,
            classes=self.classes,
            augmentation=augmentation,
            preprocessing=self.config.preprocessing
        )

        total_images = len(dataset)
        start_time = time.time()

        for idx in range(total_images):
            if (idx + 1) % 10 == 0 or idx == total_images - 1:
                elapsed = time.time() - start_time
                print(f'\rProcessed {idx+1}/{total_images} images in {elapsed:.1f}s',
                      end='')

            image, height, width = dataset[idx]
            filename = dataset.filenames[idx]

            prediction = self._predict_label_mask(image, (height, width))
            prediction = self._format_prediction(prediction, format)
            self._save_prediction(prediction, filename, output_dir, format)

        print(f'\nPredictions saved to: {output_dir}')
        return output_dir

    # ------------------------------------------------------------------
    # Large rasters
    # ------------------------------------------------------------------

    def _tile_labels_and_confidence(self, prediction):
        """Split a squeezed model output into a label tile and a confidence tile.

        A 3D output whose first axis equals ``n_classes`` is treated as
        per-class scores (arg-max label, max score as confidence).  Anything
        else is treated as a single-channel output: the label is the rounded
        value and the confidence is the distance from 0.5.

        Returns:
            tuple: ``(labels uint8, confidence float32)``.
        """
        if prediction.ndim == 3 and prediction.shape[0] == self.config.n_classes:
            tile_pred = np.argmax(prediction, axis=0).astype(np.uint8)
            tile_conf = np.max(prediction, axis=0).astype(np.float32)
            return tile_pred, tile_conf

        if prediction.ndim == 3:
            prediction = prediction[0]
        tile_conf = np.abs(prediction - 0.5).astype(np.float32)
        tile_pred = np.round(prediction).astype(np.uint8)
        return tile_pred, tile_conf

    @staticmethod
    def _tile_polygon(transform, rows, cols):
        """Build the georeferenced footprint polygon of a tile from its row/column slices."""
        return Polygon([
            convert_coordinates(transform, cols.start, rows.start),
            convert_coordinates(transform, cols.stop, rows.start),
            convert_coordinates(transform, cols.stop, rows.stop),
            convert_coordinates(transform, cols.start, rows.stop),
        ])

    def predict_raster(
        self,
        raster_path,
        tile_size=1024,
        overlap=0.175,
        boundary_path=None,
        output_path=None,
        format='integer'
    ):
        """
        Processes large raster images using a tiling approach. For each tile:
          1) Optionally checks a boundary geometry (if provided) to skip tiles outside an ROI.
          2) Applies augmentations/preprocessing, then runs the model prediction.
          3) Resizes back to the tile's original size if necessary (e.g., after aug).
          4) Merges the tile predictions into a final 'pred_raster'; where tiles
             overlap, the pixel with the higher confidence wins.

        Args:
            raster_path (str): Path to the large raster image (GeoTIFF).
            tile_size (int): Dimensions of each tile (default 1024).
            overlap (float): Overlap fraction between tiles (default 0.175).
            boundary_path (str): Path to shapefile/geojson for boundary region (optional).
                Only the first feature of the file is used.
            output_path (str): Path to save prediction (optional).
            format (str): 'integer' for integer mask, 'color' for RGB.

        Returns:
            pred_raster (np.ndarray): 2D or 3D numpy array with the final merged prediction
                (always the full, uncropped raster extent).
            profile (dict): Raster profile/metadata of the source; when the prediction
                is saved, its dtype/count are updated to match the written file.
        """
        print("Loading raster...")
        with rio.open(raster_path) as src:
            raster = src.read()
            profile = src.profile
            transform = src.transform
        # Bands-first -> height/width/channel, keeping at most three bands.
        raster = np.moveaxis(raster, 0, 2)[:, :, :3]

        boundary_geom = None
        if boundary_path:
            boundary = gpd.read_file(boundary_path)
            boundary = boundary.to_crs(profile['crs'])
            boundary_geom = boundary.iloc[0].geometry

        print("Generating tiles...")
        tiles = slidingwindow.generate(
            raster,
            slidingwindow.DimOrder.HeightWidthChannel,
            tile_size,
            overlap
        )
        total_tiles = len(tiles)

        pred_raster = np.zeros(raster.shape[:2], dtype='uint8')
        confidence = np.zeros(raster.shape[:2], dtype=np.float32)

        aug = get_validation_augmentations(tile_size, tile_size, fixed_size=False)

        for idx, tile in enumerate(tiles):
            if (idx + 1) % 10 == 0 or idx == total_tiles - 1:
                print(f"\rProcessed {idx + 1}/{total_tiles} tiles", end="")

            bounds = tile.indices()
            rows, cols = bounds[0], bounds[1]
            tile_image = raster[rows, cols]

            # Degenerate tiles carry no pixels to predict.
            if tile_image.shape[0] == 0 or tile_image.shape[1] == 0:
                continue

            if boundary_geom is not None:
                footprint = self._tile_polygon(transform, rows, cols)
                if not footprint.intersects(boundary_geom):
                    continue

            processed = aug(image=tile_image)['image']
            if processed.shape[0] == 0 or processed.shape[1] == 0:
                continue

            processed = self.config.preprocessing(image=processed)['image']
            x_tensor = torch.from_numpy(processed).to(self.device).unsqueeze(0)

            with torch.no_grad():
                prediction = self.model.predict(x_tensor)
            prediction = prediction.squeeze().cpu().numpy()

            tile_pred, tile_conf = self._tile_labels_and_confidence(prediction)

            orig_hw = tile_image.shape[:2]
            if tile_pred.shape != orig_hw:
                # cv2 expects (width, height); labels use nearest-neighbour so
                # no new class values are invented, confidence is interpolated.
                cv2_size = (orig_hw[1], orig_hw[0])
                resized_pred = cv2.resize(
                    tile_pred.astype(np.float32), cv2_size,
                    interpolation=cv2.INTER_NEAREST
                )
                resized_conf = cv2.resize(
                    tile_conf.astype(np.float32), cv2_size,
                    interpolation=cv2.INTER_LINEAR
                )
                tile_pred = np.round(resized_pred).astype(np.uint8)
                tile_conf = resized_conf.astype(np.float32)

            # Keep, per pixel, whichever tile was more confident so far.
            better = tile_conf > confidence[rows, cols]
            pred_raster[rows, cols] = np.where(better, tile_pred, pred_raster[rows, cols])
            confidence[rows, cols] = np.where(better, tile_conf, confidence[rows, cols])

        print("\n Finished all tiles")

        pred_raster = self._format_prediction(pred_raster, format)

        if output_path or boundary_path:
            self._save_raster_prediction(
                pred_raster,
                raster_path,
                output_path,
                profile,
                boundary_geom
            )

        return pred_raster, profile

    @staticmethod
    def _default_raster_output_path(source_path):
        """Return ``<source without extension>_predicted.tif``.

        Only the trailing extension is replaced, so directory names that
        contain the same text and extension-less paths are handled correctly.
        """
        return os.path.splitext(source_path)[0] + '_predicted.tif'

    def _save_raster_prediction(self, prediction, source_path, output_path,
                                profile, boundary=None):
        """Save a raster prediction with geospatial information.

        Args:
            prediction (np.ndarray): 2D mask or ``(H, W, 3)`` colour image.
            source_path (str): Source raster; used to derive a default output path.
            output_path (str): Destination file; defaults to
                ``_default_raster_output_path(source_path)``.
            profile (dict): Raster profile; its ``dtype`` and ``count`` are
                updated in place to describe the written prediction.
            boundary: Optional geometry; when given (and not empty) the written
                file is replaced by a version cropped to it.  The crop is
                applied on a copy of ``profile`` so the caller's profile keeps
                describing the uncropped prediction.
        """
        output_path = output_path or self._default_raster_output_path(source_path)

        profile.update(
            dtype='uint8',
            count=3 if prediction.ndim == 3 else 1
        )

        with rio.open(output_path, 'w', **profile) as dst:
            if prediction.ndim == 3:
                for band in range(3):
                    dst.write(prediction[:, :, band], band + 1)
            else:
                dst.write(prediction, 1)

        if boundary is not None and not boundary.is_empty:
            with rio.open(output_path) as src:
                cropped, transform = riomask.mask(src, [boundary], crop=True)

            cropped_profile = profile.copy()
            cropped_profile.update(
                height=cropped.shape[1],
                width=cropped.shape[2],
                transform=transform
            )

            os.remove(output_path)
            with rio.open(output_path, 'w', **cropped_profile) as dst:
                dst.write(cropped)

        print(f'\nPrediction saved to: {output_path}')

    # ------------------------------------------------------------------
    # Video frames
    # ------------------------------------------------------------------

    @classmethod
    def _overlay_class_masks(cls, frame, masks):
        """Tint ``frame`` in place: 45% original pixel + 55% overlay colour.

        The first mask is tinted red, all following masks blue.  The colour is
        broadcast over the selected pixels, so no frame-sized colour image is
        required and frames of any size/orientation are supported.

        Args:
            frame (np.ndarray): ``(H, W, 3)`` RGB image, modified in place.
            masks (list): Boolean ``(H, W)`` arrays.

        Returns:
            np.ndarray: The same ``frame`` object.
        """
        for position, mask in enumerate(masks):
            rgb = cls._OVERLAY_PRIMARY if position == 0 else cls._OVERLAY_SECONDARY
            frame[mask] = 0.45 * frame[mask] + 0.55 * np.asarray(rgb)
        return frame

    def predict_video_frames(self, input_dir, target_size=None, output_dir=None):
        """Predict every frame in ``input_dir`` and save colour-overlaid copies.

        Args:
            input_dir (str): Directory handed to ``StreamingDataset``; the
                original frames are re-read from here for the overlay.
            target_size (int): Optional square augmentation size.
            output_dir (str): Output directory; defaults to
                ``<input_dir>/predictions``.

        Returns:
            str: The directory the overlaid frames were written to.
        """
        output_dir = output_dir or os.path.join(input_dir, 'predictions')
        os.makedirs(output_dir, exist_ok=True)

        augmentation = None
        if target_size:
            augmentation = get_validation_augmentations(target_size, target_size)
        dataset = StreamingDataset(
            input_dir,
            classes=self.classes,
            augmentation=augmentation,
            preprocessing=self.config.preprocessing
        )

        total_frames = len(dataset)
        start_time = time.time()

        for idx in range(total_frames):
            if (idx + 1) % 10 == 0 or idx == total_frames - 1:
                elapsed = time.time() - start_time
                print(f'\rProcessed {idx+1}/{total_frames} frames in {elapsed:.1f}s', end='')

            frame, height, width = dataset[idx]
            filename = dataset.filenames[idx]

            # Resize first, then derive the masks, so they match the frame size.
            prediction = self._predict_label_mask(frame, (height, width))
            if self.config.n_classes > 1:
                masks = [prediction == i for i in range(1, self.config.n_classes)]
            else:
                masks = [prediction == 1]

            original = cv2.imread(os.path.join(input_dir, filename))
            if original is None:
                print(f"\nWarning: Error processing frame {filename}")
                continue
            original = cv2.cvtColor(original, cv2.COLOR_BGR2RGB)

            try:
                self._overlay_class_masks(original, masks)
            except (IndexError, ValueError):
                # Mask and frame sizes disagree; skip this frame, keep going.
                print(f"\nWarning: Error processing frame {filename}")
                continue

            output_path = os.path.join(output_dir, filename)
            imageio.imwrite(output_path, original, quality=100)

        print(f'\nProcessed frames saved to: {output_dir}')
        return output_dir