# vista2d/scripts/components.py
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import cv2
import fastremap
import numpy as np
import PIL.Image
import tifffile
import torch
import torch.nn.functional as F
from cellpose.dynamics import compute_masks, masks_to_flows
from cellpose.metrics import _intersection_over_union, _true_positive
from monai.apps import get_logger
from monai.data import MetaTensor
from monai.transforms import MapTransform
from monai.utils import ImageMetaKey, convert_to_dst_type

logger = get_logger("VistaCell")


class LoadTiffd(MapTransform):
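    """
    Load an image as a channels-first array and wrap it in a MetaTensor carrying the
    filename and original spatial shape. tif/tiff files are read with tifffile, all
    other formats (png, jpeg, ...) with PIL. Single-channel images are repeated to
    3 channels; multi-channel labels are reduced to their first channel.

    Illustrative usage (a sketch; the file names are placeholders, not part of this script):

        loader = LoadTiffd(keys=["image", "label"])
        sample = loader({"image": "cells.tif", "label": "cells_masks.tif"})
        sample["image"].shape  # (3, H, W)
    """
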
def __call__(self, data):
d = dict(data)
for key in self.key_iterator(d):
filename = d[key]
extension = os.path.splitext(filename)[1][1:]
image_size = None
if extension in ["tif", "tiff"]:
img_array = tifffile.imread(filename) # use tifffile for tif images
image_size = img_array.shape
if len(img_array.shape) == 3 and img_array.shape[-1] <= 3:
                    img_array = np.transpose(img_array, (2, 0, 1))  # move channels to the first dimension
else:
img_array = np.array(PIL.Image.open(filename)) # PIL for all other images (png, jpeg)
image_size = img_array.shape
if len(img_array.shape) == 3:
img_array = np.transpose(img_array, (2, 0, 1)) # channels first
if len(img_array.shape) not in [2, 3]:
raise ValueError(
"Unsupported image dimensions, filename " + str(filename) + " shape " + str(img_array.shape)
)
if len(img_array.shape) == 2:
                img_array = img_array[np.newaxis]  # add a leading channel dimension for 2D input
if key == "label":
if img_array.shape[0] > 1:
                    print(
                        f"Unexpected label with multiple channels {filename} shape {img_array.shape}, keeping only the first channel"
                    )
img_array = img_array[[0]]
elif key == "image":
if img_array.shape[0] == 1:
img_array = np.repeat(img_array, repeats=3, axis=0) # if grayscale, repeat as 3 channels
elif img_array.shape[0] == 2:
                    print(
                        f"Unexpected image with 2 channels {filename} shape {img_array.shape}, appending the first channel to make 3"
                    )
                    img_array = np.stack(
                        (img_array[0], img_array[1], img_array[0]), axis=0
                    )  # build a 3-channel image from the unexpected 2-channel input
elif img_array.shape[0] > 3:
print(f"Strange case, image with >3 channels, {filename} shape {img_array.shape}, keeping first 3")
img_array = img_array[:3]
meta_data = {ImageMetaKey.FILENAME_OR_OBJ: filename, ImageMetaKey.SPATIAL_SHAPE: image_size}
d[key] = MetaTensor.ensure_torch_and_prune_meta(img_array, meta_data)
return d


class SaveTiffd(MapTransform):
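    """
    Save an integer label map to <output_dir>/<basename>.tif, using the smallest unsigned
    dtype (uint8/uint16/uint32) that can hold the largest instance id. With nested_folder=True,
    the path of the source file relative to data_root_dir is reproduced under output_dir.

    Illustrative usage (a sketch; key and path names are placeholders):

        saver = SaveTiffd(output_dir="./eval", keys="pred")
        saver({"pred": pred})  # pred: MetaTensor carrying filename metadata
    """
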
def __init__(self, output_dir, data_root_dir="/", nested_folder=False, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.output_dir = output_dir
self.data_root_dir = data_root_dir
self.nested_folder = nested_folder

    def set_data_root_dir(self, data_root_dir):
self.data_root_dir = data_root_dir

    def __call__(self, data):
d = dict(data)
os.makedirs(self.output_dir, exist_ok=True)
for key in self.key_iterator(d):
seg = d[key]
filename = seg.meta[ImageMetaKey.FILENAME_OR_OBJ]
basename = os.path.splitext(os.path.basename(filename))[0]
if self.nested_folder:
reldir = os.path.relpath(os.path.dirname(filename), self.data_root_dir)
outdir = os.path.join(self.output_dir, reldir)
os.makedirs(outdir, exist_ok=True)
else:
outdir = self.output_dir
outname = os.path.join(outdir, basename + ".tif")
label = seg.cpu().numpy()
lm = label.max()
if lm <= 255:
label = label.astype(np.uint8)
elif lm <= 65535:
label = label.astype(np.uint16)
else:
label = label.astype(np.uint32)
tifffile.imwrite(outname, label)
print(f"Saving {outname} shape {label.shape} max {label.max()} dtype {label.dtype}")
return d


class LabelsToFlows(MapTransform):
    # based on dynamics labels_to_flows()
    # creates a 3-channel output (foreground, flowx, flowy) and saves it under the (new) flow key
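    """
    Illustrative usage (a sketch; key names are placeholders):

        to_flows = LabelsToFlows(flow_key="flow", keys="label")
        sample = to_flows({"label": label})  # label: (1, H, W) tensor; adds a (3, H, W) "flow" entry
    """
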
def __init__(self, flow_key, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.flow_key = flow_key

    def __call__(self, data):
d = dict(data)
for key in self.key_iterator(d):
label = d[key].int().numpy()
label = fastremap.renumber(label, in_place=True)[0]
veci = masks_to_flows(label[0], device=None)
flows = np.concatenate((label > 0.5, veci), axis=0).astype(np.float32)
flows = convert_to_dst_type(flows, d[key], dtype=torch.float, device=d[key].device)[0]
d[self.flow_key] = flows
# meta_data = {ImageMetaKey.FILENAME_OR_OBJ : filename}
# d[key] = MetaTensor.ensure_torch_and_prune_meta(img_array, meta_data)
return d


class LogitsToLabels:
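    """
    Convert a (3, H, W) network output (channel 0: foreground logit, channels 1-2: flow field)
    into an instance mask with cellpose's compute_masks, retrying on CPU if the GPU run fails.

    Illustrative usage (a sketch; the logits tensor is a placeholder):

        pred_mask, p = LogitsToLabels()(logits)  # logits: torch.Tensor of shape (3, H, W)
    """
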
def __call__(self, logits, filename=None):
device = logits.device
logits = logits.float().cpu().numpy()
dp = logits[1:] # vectors
cellprob = logits[0] # foreground prob (logit)
try:
pred_mask, p = compute_masks(
dp, cellprob, niter=200, cellprob_threshold=0.4, flow_threshold=0.4, interp=True, device=device
)
except RuntimeError as e:
logger.warning(f"compute_masks failed on GPU retrying on CPU {logits.shape} file {filename} {e}")
pred_mask, p = compute_masks(
dp, cellprob, niter=200, cellprob_threshold=0.4, flow_threshold=0.4, interp=True, device=None
)
return pred_mask, p


class LogitsToLabelsd(MapTransform):
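    """
    Dictionary wrapper around LogitsToLabels: replaces each logits entry with its instance
    mask and stores the auxiliary compute_masks output under "<key>_centroids".

    Illustrative usage (a sketch; the key name is a placeholder):

        post = LogitsToLabelsd(keys="pred")
        sample = post({"pred": logits})
    """
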
def __call__(self, data):
d = dict(data)
f = LogitsToLabels()
for key in self.key_iterator(d):
pred_mask, p = f(d[key])
d[key] = pred_mask
d[f"{key}_centroids"] = p
return d


class SaveTiffExd(MapTransform):
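    """
    Extended saver: writes the instance mask to <output_dir>/<basename>_<postfix><ext>,
    optionally overlays instance contours on the original image, writes a meta.json summary,
    and optionally a contours.json with the extracted polygons. The dictionary keys
    "output_dir", "output_ext", "overlayed_masks" and "output_contours" override the defaults.

    Illustrative usage (a sketch; key and path names are placeholders):

        saver = SaveTiffExd(keys="pred", output_dir="./eval", output_ext=".tif")
        saver({"pred": pred_mask, "image": image})  # pred_mask: numpy instance mask
    """
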
def __init__(self, output_dir, output_ext=".png", output_postfix="seg", image_key="image", *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.output_dir = output_dir
self.output_ext = output_ext
self.output_postfix = output_postfix
self.image_key = image_key

    def to_polygons(self, contours):
polygons = []
for contour in contours:
if len(contour) < 3:
continue
polygons.append(np.squeeze(contour).astype(int).tolist())
return polygons

    def __call__(self, data):
d = dict(data)
output_dir = d.get("output_dir", self.output_dir)
output_ext = d.get("output_ext", self.output_ext)
overlayed_masks = d.get("overlayed_masks", False)
output_contours = d.get("output_contours", False)
        os.makedirs(output_dir, exist_ok=True)  # create the (possibly overridden) output directory
img = d.get(self.image_key, None)
filename = img.meta.get(ImageMetaKey.FILENAME_OR_OBJ) if img is not None else None
image_size = img.meta.get(ImageMetaKey.SPATIAL_SHAPE) if img is not None else None
basename = os.path.splitext(os.path.basename(filename))[0] if filename else "mask"
logger.info(f"File: {filename}; Base: {basename}")
for key in self.key_iterator(d):
label = d[key]
output_filename = f"{basename}{'_' + self.output_postfix if self.output_postfix else ''}{output_ext}"
output_filepath = os.path.join(output_dir, output_filename)
lm = label.max()
logger.info(f"Mask Shape: {label.shape}; Instances: {lm}")
if lm <= 255:
label = label.astype(np.uint8)
elif lm <= 65535:
label = label.astype(np.uint16)
else:
label = label.astype(np.uint32)
tifffile.imwrite(output_filepath, label)
logger.info(f"Saving {output_filepath}")
polygons = []
if overlayed_masks:
logger.info(f"Overlay Masks: Reading original Image: {filename}")
image = cv2.imread(filename)
mask = cv2.imread(output_filepath, 0)
                for i in range(1, np.max(mask) + 1):  # include the highest instance id
m = np.zeros_like(mask)
m[mask == i] = 1
color = np.random.choice(range(256), size=3).tolist()
contours, _ = cv2.findContours(m, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
polygons.extend(self.to_polygons(contours))
cv2.drawContours(image, contours, -1, color, 1)
cv2.imwrite(output_filepath, image)
logger.info(f"Overlay Masks: Saving {output_filepath}")
else:
label = cv2.convertScaleAbs(label, alpha=255.0 / label.max())
contours, _ = cv2.findContours(label, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
polygons.extend(self.to_polygons(contours))
meta_json = {"image_size": image_size, "contours": len(polygons)}
with open(os.path.join(output_dir, "meta.json"), "w") as fp:
json.dump(meta_json, fp, indent=2)
if output_contours:
logger.info(f"Total Polygons: {len(polygons)}")
with open(os.path.join(output_dir, "contours.json"), "w") as fp:
json.dump({"count": len(polygons), "contours": polygons}, fp, indent=2)
return d


# Loss (adapted from Cellpose)
class CellLoss:
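    """
    Cellpose-style training loss: 0.5 * MSE on the two flow channels (targets scaled by 5)
    plus binary cross-entropy with logits on the foreground channel.

    Illustrative usage (a sketch; the tensor shapes are assumptions):

        loss = CellLoss()(y_pred, y)  # both (B, 3, H, W); channel 0 = foreground, 1: = flows
    """
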
def __call__(self, y_pred, y):
loss = 0.5 * F.mse_loss(y_pred[:, 1:], 5 * y[:, 1:]) + F.binary_cross_entropy_with_logits(
y_pred[:, [0]], y[:, [0]]
)
return loss


# Accuracy (adapted from Cellpose)
class CellAcc:
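    """
    Cellpose-style average precision at IoU threshold 0.5: matches predicted and ground-truth
    instances and returns tp / (tp + fp + fn).

    Illustrative usage (a sketch; the masks are assumptions):

        ap = CellAcc()(pred_mask, true_mask)  # integer instance masks, e.g. (H, W) arrays
    """
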
def __call__(self, mask_pred, mask_true):
if isinstance(mask_true, torch.Tensor):
mask_true = mask_true.cpu().numpy()
if isinstance(mask_pred, torch.Tensor):
mask_pred = mask_pred.cpu().numpy()
# print("CellAcc mask_true", mask_true.shape, 'max', np.max(mask_true), ",
# "'mask_pred', mask_pred.shape, 'max', np.max(mask_pred) )
iou = _intersection_over_union(mask_true, mask_pred)[1:, 1:]
tp = _true_positive(iou, th=0.5)
fp = np.max(mask_pred) - tp
fn = np.max(mask_true) - tp
ap = tp / (tp + fp + fn)
# print("CellAcc ap", ap, 'tp', tp, 'fp', fp, 'fn', fn)
return ap
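

if __name__ == "__main__":
    # Minimal smoke test, not part of the original pipeline: exercises CellLoss and CellAcc
    # on synthetic data with assumed shapes ((B, 3, H, W) predictions, (H, W) instance masks).
    y = torch.cat([torch.randint(0, 2, (2, 1, 64, 64)).float(), torch.randn(2, 2, 64, 64)], dim=1)
    y_pred = torch.randn(2, 3, 64, 64)
    print("CellLoss:", CellLoss()(y_pred, y).item())

    true_mask = np.zeros((64, 64), dtype=np.int32)
    true_mask[10:30, 10:30] = 1
    true_mask[40:60, 40:60] = 2
    print("CellAcc:", CellAcc()(true_mask.copy(), true_mask))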