|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json |
|
import os |
|
|
|
import cv2 |
|
import fastremap |
|
import numpy as np |
|
import PIL |
|
import tifffile |
|
import torch |
|
import torch.nn.functional as F |
|
from cellpose.dynamics import compute_masks, masks_to_flows |
|
from cellpose.metrics import _intersection_over_union, _true_positive |
|
from monai.apps import get_logger |
|
from monai.data import MetaTensor |
|
from monai.transforms import MapTransform |
|
from monai.utils import ImageMetaKey, convert_to_dst_type |
|
|
|
# Module-level logger used by the transforms in this file (created via MONAI's ``get_logger``).
logger = get_logger("VistaCell")
|
|
|
|
|
class LoadTiffd(MapTransform):
    """Load 2D images/labels from file paths into channel-first ``MetaTensor`` objects.

    Files whose extension is ``tif``/``tiff`` (case-insensitive) are read with
    ``tifffile``; any other extension is read with Pillow. Channel-last arrays are
    transposed to channel-first, and 2D arrays get a leading channel axis.

    Per-key channel handling:
      * ``"label"``: only the first channel is kept.
      * ``"image"``: the result has 3 channels (a single channel is repeated, for
        2 channels the first one is appended, extra channels are dropped).

    The meta dict records the file path (``ImageMetaKey.FILENAME_OR_OBJ``) and the
    array shape as read from disk, before any transpose (``ImageMetaKey.SPATIAL_SHAPE``).

    Raises:
        ValueError: if the loaded array is neither 2D nor 3D.
    """

    @staticmethod
    def _read_array(filename):
        """Read ``filename``; return ``(array, shape_as_read)`` with 3D arrays channel-first."""
        from PIL import Image  # ``import PIL`` alone does not guarantee ``PIL.Image`` is loaded

        extension = os.path.splitext(filename)[1][1:].lower()
        if extension in ("tif", "tiff"):
            img_array = tifffile.imread(filename)
            image_size = img_array.shape
            # A small trailing axis (up to 4, i.e. RGBA) is taken to be the channel axis;
            # otherwise the array is assumed to already be channel-first.
            if img_array.ndim == 3 and img_array.shape[-1] <= 4:
                img_array = np.transpose(img_array, (2, 0, 1))
        else:
            # The context manager closes the file handle Pillow keeps open for lazy loading.
            with Image.open(filename) as pil_image:
                img_array = np.array(pil_image)
            image_size = img_array.shape
            # Pillow returns multi-band images channel-last.
            if img_array.ndim == 3:
                img_array = np.transpose(img_array, (2, 0, 1))
        return img_array, image_size

    def __call__(self, data):
        d = dict(data)
        for key in self.key_iterator(d):
            filename = d[key]
            img_array, image_size = self._read_array(filename)

            if img_array.ndim not in (2, 3):
                raise ValueError(
                    "Unsupported image dimensions, filename " + str(filename) + " shape " + str(img_array.shape)
                )

            if img_array.ndim == 2:
                img_array = img_array[np.newaxis]

            if key == "label":
                if img_array.shape[0] > 1:
                    logger.warning(
                        f"Strange case, label with several channels {filename} shape {img_array.shape}, keeping only first"
                    )
                    img_array = img_array[[0]]

            elif key == "image":
                if img_array.shape[0] == 1:
                    img_array = np.repeat(img_array, repeats=3, axis=0)
                elif img_array.shape[0] == 2:
                    logger.warning(
                        f"Strange case, image with 2 channels {filename} shape {img_array.shape}, appending first channel to make 3"
                    )
                    img_array = np.stack((img_array[0], img_array[1], img_array[0]), axis=0)
                elif img_array.shape[0] > 3:
                    logger.warning(
                        f"Strange case, image with >3 channels, {filename} shape {img_array.shape}, keeping first 3"
                    )
                    img_array = img_array[:3]

            meta_data = {ImageMetaKey.FILENAME_OR_OBJ: filename, ImageMetaKey.SPATIAL_SHAPE: image_size}
            d[key] = MetaTensor.ensure_torch_and_prune_meta(img_array, meta_data)

        return d
|
|
|
|
|
class SaveTiffd(MapTransform):
    """Write label tensors to ``<output_dir>[/<relative dir>]/<basename>.tif``.

    The basename is taken from the tensor's ``ImageMetaKey.FILENAME_OR_OBJ`` meta
    entry (extension stripped). The array is cast to the smallest of
    uint8/uint16/uint32 that holds its maximum value before writing.

    Args:
        output_dir: directory the TIFF files are written to (created if missing).
        data_root_dir: root used to compute the relative sub-folder when
            ``nested_folder`` is True.
        nested_folder: if True, mirror the input file's directory (relative to
            ``data_root_dir``) under ``output_dir``.
    """

    def __init__(self, output_dir, data_root_dir="/", nested_folder=False, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.output_dir = output_dir
        self.data_root_dir = data_root_dir
        self.nested_folder = nested_folder

    def set_data_root_dir(self, data_root_dir):
        """Change the root directory used to compute nested output sub-folders."""
        self.data_root_dir = data_root_dir

    @staticmethod
    def _to_smallest_uint(label):
        """Cast ``label`` to the smallest of uint8/uint16/uint32 that holds its maximum."""
        lm = label.max()
        if lm <= 255:
            return label.astype(np.uint8)
        if lm <= 65535:
            return label.astype(np.uint16)
        return label.astype(np.uint32)

    def _output_dir_for(self, filename):
        """Return (and create, when nesting) the directory that ``filename``'s output goes to."""
        if not self.nested_folder:
            return self.output_dir
        reldir = os.path.relpath(os.path.dirname(filename), self.data_root_dir)
        outdir = os.path.join(self.output_dir, reldir)
        os.makedirs(outdir, exist_ok=True)
        return outdir

    def __call__(self, data):
        d = dict(data)
        os.makedirs(self.output_dir, exist_ok=True)

        for key in self.key_iterator(d):
            seg = d[key]
            filename = seg.meta[ImageMetaKey.FILENAME_OR_OBJ]
            basename = os.path.splitext(os.path.basename(filename))[0]

            outname = os.path.join(self._output_dir_for(filename), basename + ".tif")

            label = self._to_smallest_uint(seg.detach().cpu().numpy())
            tifffile.imwrite(outname, label)

            # Use the module logger (consistent with the other transforms) instead of print.
            logger.info(f"Saving {outname} shape {label.shape} max {label.max()} dtype {label.dtype}")

        return d
|
|
|
|
|
class LabelsToFlows(MapTransform):
    """Convert an instance-label tensor into Cellpose-style training targets.

    For each selected key, the label is renumbered with ``fastremap.renumber`` and
    its first channel is passed to ``cellpose.dynamics.masks_to_flows``. The value
    stored under ``flow_key`` concatenates, along axis 0, the foreground mask
    (``label > 0.5``) and the array returned by ``masks_to_flows``. It is cast to
    float32 and converted to the input's container type and device. The input label
    itself is left untouched.

    Args:
        flow_key: dictionary key under which the flow target is stored.
    """

    def __init__(self, flow_key, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.flow_key = flow_key

    def __call__(self, data):
        d = dict(data)
        for key in self.key_iterator(d):
            # `.numpy()` raises for CUDA tensors, so move to CPU first.
            label = d[key].detach().cpu().int().numpy()

            # `.int()` / `.cpu()` return the same storage when no conversion is needed, so
            # renumber a private copy; in-place renumbering would rewrite the caller's label.
            label = fastremap.renumber(label.copy(), in_place=True)[0]

            # assumes a channel-first label of shape (1, H, W) — TODO confirm; only channel 0 is used.
            veci = masks_to_flows(label[0], device=None)

            flows = np.concatenate((label > 0.5, veci), axis=0).astype(np.float32)
            flows = convert_to_dst_type(flows, d[key], dtype=torch.float, device=d[key].device)[0]
            d[self.flow_key] = flows

        return d
|
|
|
|
|
class LogitsToLabels:
    """Convert network output into an instance mask with ``cellpose.dynamics.compute_masks``.

    ``logits[0]`` is passed as the cell-probability map and ``logits[1:]`` as the
    flow field. If ``compute_masks`` raises ``RuntimeError`` on the logits' device,
    it is retried once with ``device=None``.

    Args:
        niter: number of iterations forwarded to ``compute_masks``.
        cellprob_threshold: forwarded to ``compute_masks``.
        flow_threshold: forwarded to ``compute_masks``.
        interp: forwarded to ``compute_masks``.
        The defaults reproduce the values previously hard-coded here.

    Returns (from ``__call__``):
        ``(pred_mask, p)`` — the two values returned by ``compute_masks``.
    """

    def __init__(self, niter=200, cellprob_threshold=0.4, flow_threshold=0.4, interp=True):
        self.niter = niter
        self.cellprob_threshold = cellprob_threshold
        self.flow_threshold = flow_threshold
        self.interp = interp

    def _compute(self, dp, cellprob, device):
        """Run ``compute_masks`` with this instance's settings on ``device``."""
        return compute_masks(
            dp,
            cellprob,
            niter=self.niter,
            cellprob_threshold=self.cellprob_threshold,
            flow_threshold=self.flow_threshold,
            interp=self.interp,
            device=device,
        )

    def __call__(self, logits, filename=None):
        device = logits.device
        # detach() so tensors that still require grad can be converted to numpy.
        logits = logits.detach().float().cpu().numpy()
        dp = logits[1:]
        cellprob = logits[0]

        try:
            pred_mask, p = self._compute(dp, cellprob, device)
        except RuntimeError as e:
            logger.warning(f"compute_masks failed on GPU retrying on CPU {logits.shape} file {filename} {e}")
            pred_mask, p = self._compute(dp, cellprob, None)

        return pred_mask, p
|
|
|
|
|
class LogitsToLabelsd(MapTransform):
    """Dictionary wrapper around ``LogitsToLabels``.

    For every selected key, the logits are replaced by the predicted instance mask.
    The second value returned by ``LogitsToLabels`` is stored under
    ``"<key>_centroids"``.
    """

    def __call__(self, data):
        converter = LogitsToLabels()
        out = dict(data)
        for name in self.key_iterator(out):
            mask, centroids = converter(out[name])
            out[name] = mask
            out[f"{name}_centroids"] = centroids
        return out
|
|
|
|
|
class SaveTiffExd(MapTransform):
    """Save predicted instance masks plus optional overlay and contour outputs.

    For every selected key, the mask is written with ``tifffile`` to
    ``<output_dir>/<basename>[_<output_postfix>]<output_ext>``. Note that TIFF data
    is written regardless of ``output_ext``. ``basename`` comes from the
    ``FILENAME_OR_OBJ`` meta entry of ``data[image_key]``, or is ``"mask"`` when
    unavailable. A ``meta.json`` holding the original image size and the number of
    contour polygons is written to the output directory.

    Per-sample overrides read from ``data``:
      * ``output_dir`` / ``output_ext``: replace the constructor values.
      * ``overlayed_masks``: draw every instance outline on the original image and
        overwrite the output file with that overlay.
      * ``output_contours``: additionally write all polygons to ``contours.json``.

    Args:
        output_dir: default output directory.
        output_ext: default output file extension (including the dot).
        output_postfix: suffix appended to the basename (omitted if empty).
        image_key: key of the input image whose meta provides filename and size.

    Raises:
        ValueError: if ``overlayed_masks`` is set but the original image cannot be read.
    """

    def __init__(self, output_dir, output_ext=".png", output_postfix="seg", image_key="image", *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.output_dir = output_dir
        self.output_ext = output_ext
        self.output_postfix = output_postfix
        self.image_key = image_key

    def to_polygons(self, contours):
        """Convert OpenCV contours to lists of ``[x, y]`` points, skipping contours with < 3 points."""
        polygons = []
        for contour in contours:
            if len(contour) < 3:
                continue
            polygons.append(np.squeeze(contour).astype(int).tolist())
        return polygons

    @staticmethod
    def _to_smallest_uint(label, lm):
        """Cast ``label`` to the smallest of uint8/uint16/uint32 that holds ``lm`` (its maximum)."""
        if lm <= 255:
            return label.astype(np.uint8)
        if lm <= 65535:
            return label.astype(np.uint16)
        return label.astype(np.uint32)

    @staticmethod
    def _as_2d(label):
        """Drop a leading singleton channel axis so OpenCV receives a 2D mask."""
        return label[0] if label.ndim == 3 and label.shape[0] == 1 else label

    def _overlay_contours(self, filename, mask, output_filepath):
        """Draw each instance's outline on the original image, save it, and return the polygons."""
        image = cv2.imread(filename) if filename else None
        if image is None:
            raise ValueError(f"Overlay Masks: unable to read original image: {filename}")

        polygons = []
        # Instance ids run from 1 to max inclusive; the mask is used directly rather than
        # re-read from disk, because IMREAD_GRAYSCALE would truncate uint16 labels to 8 bits.
        for i in range(1, int(mask.max()) + 1):
            instance = (mask == i).astype(np.uint8)
            color = np.random.choice(range(256), size=3).tolist()
            contours, _ = cv2.findContours(instance, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
            polygons.extend(self.to_polygons(contours))
            cv2.drawContours(image, contours, -1, color, 1)

        cv2.imwrite(output_filepath, image)
        return polygons

    def _mask_contours(self, mask):
        """Return polygons outlining all foreground (non-zero) regions of ``mask``."""
        # findContours only distinguishes zero from non-zero, so binarize directly. Rescaling
        # by 255 / max would divide by zero on empty masks and round small ids to 0 for large max.
        foreground = (mask > 0).astype(np.uint8)
        contours, _ = cv2.findContours(foreground, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
        return self.to_polygons(contours)

    def __call__(self, data):
        d = dict(data)

        output_dir = d.get("output_dir", self.output_dir)
        output_ext = d.get("output_ext", self.output_ext)
        overlayed_masks = d.get("overlayed_masks", False)
        output_contours = d.get("output_contours", False)

        # Create the directory that is actually written to (it may be overridden per sample).
        os.makedirs(output_dir, exist_ok=True)

        img = d.get(self.image_key, None)
        filename = img.meta.get(ImageMetaKey.FILENAME_OR_OBJ) if img is not None else None
        image_size = img.meta.get(ImageMetaKey.SPATIAL_SHAPE) if img is not None else None
        basename = os.path.splitext(os.path.basename(filename))[0] if filename else "mask"
        logger.info(f"File: {filename}; Base: {basename}")

        for key in self.key_iterator(d):
            label = d[key]
            if isinstance(label, torch.Tensor):
                label = label.detach().cpu().numpy()

            output_filename = f"{basename}{'_' + self.output_postfix if self.output_postfix else ''}{output_ext}"
            output_filepath = os.path.join(output_dir, output_filename)

            lm = label.max()
            logger.info(f"Mask Shape: {label.shape}; Instances: {lm}")

            label = self._to_smallest_uint(label, lm)
            tifffile.imwrite(output_filepath, label)
            logger.info(f"Saving {output_filepath}")

            mask = self._as_2d(label)
            if overlayed_masks:
                logger.info(f"Overlay Masks: Reading original Image: {filename}")
                polygons = self._overlay_contours(filename, mask, output_filepath)
                logger.info(f"Overlay Masks: Saving {output_filepath}")
            else:
                polygons = self._mask_contours(mask)

            meta_json = {"image_size": image_size, "contours": len(polygons)}
            with open(os.path.join(output_dir, "meta.json"), "w") as fp:
                json.dump(meta_json, fp, indent=2)

            if output_contours:
                logger.info(f"Total Polygons: {len(polygons)}")
                with open(os.path.join(output_dir, "contours.json"), "w") as fp:
                    json.dump({"count": len(polygons), "contours": polygons}, fp, indent=2)

        return d
|
|
|
|
|
|
|
class CellLoss:
    """Training loss: ``0.5 * MSE(flows) + BCE-with-logits(probability)``.

    Channel 0 of ``y_pred`` is compared with channel 0 of ``y`` using
    ``binary_cross_entropy_with_logits``. Channels ``1:`` are compared, using
    ``mse_loss``, against 5 times the corresponding target channels (presumably the
    flow targets built by ``LabelsToFlows`` — confirm against the training config).
    Both inputs are indexed as ``(batch, channel, ...)``.
    """

    def __call__(self, y_pred, y):
        flow_term = F.mse_loss(y_pred[:, 1:], 5 * y[:, 1:])
        prob_term = F.binary_cross_entropy_with_logits(y_pred[:, [0]], y[:, [0]])
        return 0.5 * flow_term + prob_term
|
|
|
|
|
|
|
class CellAcc:
    """Per-image instance-segmentation accuracy at an IoU threshold of 0.5.

    Returns ``tp / (tp + fp + fn)``, where:
      * ``tp`` is the number of matches returned by cellpose ``_true_positive`` on
        the IoU matrix, using ``th=0.5``;
      * ``fp = #predicted instances - tp``;
      * ``fn = #ground-truth instances - tp``.

    Instances are counted as distinct non-zero label values, so label ids do not
    need to be consecutive. If both masks are empty, NaN is returned, as the plain
    division would have produced.
    """

    @staticmethod
    def _to_numpy(mask):
        """Return ``mask`` as a numpy array (torch tensors are moved to CPU)."""
        if isinstance(mask, torch.Tensor):
            return mask.detach().cpu().numpy()
        return mask

    @staticmethod
    def _count_instances(mask):
        """Number of distinct non-zero labels in ``mask``."""
        return int(np.count_nonzero(np.unique(mask)))

    def __call__(self, mask_pred, mask_true):
        mask_true = self._to_numpy(mask_true)
        mask_pred = self._to_numpy(mask_pred)

        # Drop the background row/column (label 0).
        iou = _intersection_over_union(mask_true, mask_pred)[1:, 1:]
        tp = _true_positive(iou, th=0.5)

        # np.max() over-counts instances when label ids have gaps; count actual labels instead.
        fp = self._count_instances(mask_pred) - tp
        fn = self._count_instances(mask_true) - tp

        denominator = tp + fp + fn
        if denominator == 0:
            return float("nan")
        return tp / denominator
|
|