| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import json |
| | import os |
| |
|
| | import cv2 |
| | import fastremap |
| | import numpy as np |
| | import PIL |
| | import tifffile |
| | import torch |
| | import torch.nn.functional as F |
| | from cellpose.dynamics import compute_masks, masks_to_flows |
| | from cellpose.metrics import _intersection_over_union, _true_positive |
| | from monai.apps import get_logger |
| | from monai.data import MetaTensor |
| | from monai.transforms import MapTransform |
| | from monai.utils import ImageMetaKey, convert_to_dst_type |
| |
|
| | logger = get_logger("VistaCell") |
| |
|
| |
|
class LoadTiffd(MapTransform):
    """Read image files named by the selected keys into channel-first ``MetaTensor`` objects.

    For every selected key the dictionary value is taken to be a file path.
    Files with a ``tif``/``tiff`` extension (any letter case) are read with
    ``tifffile``; everything else is read with PIL.  The loaded array is
    rearranged to channel-first layout and then adjusted depending on the key
    name:

    * ``"label"``: only the first channel is kept.
    * ``"image"``: the channel count is forced to 3 (a single channel is
      repeated, a 2-channel image gets its first channel appended, extra
      channels beyond the third are dropped).

    Other key names are loaded without any channel adjustment.
    """

    def __call__(self, data):
        """Load the files referenced by the selected keys.

        Args:
            data: mapping whose selected keys hold file paths.

        Returns:
            A shallow copy of ``data`` in which each selected key holds the
            result of ``MetaTensor.ensure_torch_and_prune_meta`` built from
            the loaded array, with the file name and the on-disk array shape
            stored in the meta data.

        Raises:
            ValueError: if a loaded array is neither 2-D nor 3-D.
        """
        # Function-scope import: a bare ``import PIL`` does not load the
        # ``PIL.Image`` submodule, so ``PIL.Image.open`` only works when some
        # other module happened to import it first.
        from PIL import Image

        d = dict(data)
        for key in self.key_iterator(d):
            filename = d[key]

            # Lower-cased so that ".TIF" / ".TIFF" are routed to tifffile too.
            extension = os.path.splitext(filename)[1][1:].lower()

            if extension in ("tif", "tiff"):
                img_array = tifffile.imread(filename)
                image_size = img_array.shape
                # A 3-D TIFF whose last axis has at most 3 entries is treated
                # as channels-last; any other 3-D TIFF is left untouched.
                if img_array.ndim == 3 and img_array.shape[-1] <= 3:
                    img_array = np.transpose(img_array, (2, 0, 1))
            else:
                # The context manager closes the file handle as soon as the
                # pixel data has been copied into the array.
                with Image.open(filename) as pil_image:
                    img_array = np.array(pil_image)
                image_size = img_array.shape
                if img_array.ndim == 3:
                    img_array = np.transpose(img_array, (2, 0, 1))

            if img_array.ndim not in (2, 3):
                raise ValueError(
                    "Unsupported image dimensions, filename " + str(filename) + " shape " + str(img_array.shape)
                )

            # Give 2-D arrays a leading channel axis of length 1.
            if img_array.ndim == 2:
                img_array = img_array[np.newaxis]

            if key == "label":
                if img_array.shape[0] > 1:
                    print(
                        f"Strange case, label with several channels {filename} shape {img_array.shape}, keeping only first"
                    )
                    img_array = img_array[[0]]

            elif key == "image":
                if img_array.shape[0] == 1:
                    img_array = np.repeat(img_array, repeats=3, axis=0)
                elif img_array.shape[0] == 2:
                    print(
                        f"Strange case, image with 2 channels {filename} shape {img_array.shape}, appending first channel to make 3"
                    )
                    img_array = np.stack((img_array[0], img_array[1], img_array[0]), axis=0)
                elif img_array.shape[0] > 3:
                    print(f"Strange case, image with >3 channels, {filename} shape {img_array.shape}, keeping first 3")
                    img_array = img_array[:3]

            # SPATIAL_SHAPE records the array shape as read from disk, i.e.
            # before any transposition or channel adjustment above.
            meta_data = {ImageMetaKey.FILENAME_OR_OBJ: filename, ImageMetaKey.SPATIAL_SHAPE: image_size}
            d[key] = MetaTensor.ensure_torch_and_prune_meta(img_array, meta_data)

        return d
| |
|
| |
|
class SaveTiffd(MapTransform):
    """Write the tensors stored under the selected keys to ``.tif`` files.

    The output file name is derived from the source file name found in each
    tensor's meta data.  With ``nested_folder`` enabled, the source file's
    directory relative to ``data_root_dir`` is reproduced below
    ``output_dir``.

    Args:
        output_dir: directory that receives the files (created on demand).
        data_root_dir: root against which source directories are made
            relative when ``nested_folder`` is true.
        nested_folder: whether to mirror the source directory layout.
        *args: forwarded to ``MapTransform``.
        **kwargs: forwarded to ``MapTransform``.
    """

    def __init__(self, output_dir, data_root_dir="/", nested_folder=False, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.output_dir = output_dir
        self.data_root_dir = data_root_dir
        self.nested_folder = nested_folder

    def set_data_root_dir(self, data_root_dir):
        """Replace the root directory used to compute nested output folders."""
        self.data_root_dir = data_root_dir

    def __call__(self, data):
        """Save every selected entry of ``data`` and return a shallow copy of ``data``."""
        d = dict(data)
        os.makedirs(self.output_dir, exist_ok=True)

        for key in self.key_iterator(d):
            tensor = d[key]
            source_path = tensor.meta[ImageMetaKey.FILENAME_OR_OBJ]
            stem = os.path.splitext(os.path.basename(source_path))[0]

            target_dir = self.output_dir
            if self.nested_folder:
                relative_dir = os.path.relpath(os.path.dirname(source_path), self.data_root_dir)
                target_dir = os.path.join(self.output_dir, relative_dir)
                os.makedirs(target_dir, exist_ok=True)

            outname = os.path.join(target_dir, stem + ".tif")

            # Store with the narrowest unsigned integer type that can hold
            # the largest value present.
            label = tensor.cpu().numpy()
            largest = label.max()
            if largest <= 255:
                storage_dtype = np.uint8
            elif largest <= 65535:
                storage_dtype = np.uint16
            else:
                storage_dtype = np.uint32
            label = label.astype(storage_dtype)

            tifffile.imwrite(outname, label)

            print(f"Saving {outname} shape {label.shape} max {label.max()} dtype {label.dtype}")

        return d
| |
|
| |
|
class LabelsToFlows(MapTransform):
    """Build a flow target from a label map and store it under ``flow_key``.

    For each selected key the label tensor is converted to an integer NumPy
    array, renumbered with ``fastremap.renumber``, and its first channel is
    passed to cellpose's ``masks_to_flows``.  The result is concatenated
    behind the binary ``label > 0.5`` mask along axis 0 and stored as a float
    tensor.

    Args:
        flow_key: dictionary key that receives the generated tensor.
        *args: forwarded to ``MapTransform``.
        **kwargs: forwarded to ``MapTransform``.
    """

    def __init__(self, flow_key, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.flow_key = flow_key

    def __call__(self, data):
        """Compute the flow tensor for the selected keys; returns a shallow copy of ``data``."""
        d = dict(data)
        for key in self.key_iterator(d):
            source = d[key]
            label = source.int().numpy()

            label = fastremap.renumber(label, in_place=True)[0]
            # NOTE(review): presumably label is (1, H, W) so label[0] is the
            # 2-D instance map expected by masks_to_flows - confirm.
            veci = masks_to_flows(label[0], device=None)

            foreground = label > 0.5
            stacked = np.concatenate((foreground, veci), axis=0).astype(np.float32)
            # Same flow_key for every selected key: with several keys the
            # last one processed wins.
            d[self.flow_key] = convert_to_dst_type(stacked, source, dtype=torch.float, device=source.device)[0]

        return d
| |
|
| |
|
class LogitsToLabels:
    """Turn network output into masks with cellpose's ``compute_masks``."""

    def __call__(self, logits, filename=None):
        """Run ``compute_masks`` on the given logits.

        Channel 0 of ``logits`` is passed as ``cellprob`` and the remaining
        channels as ``dp``.  The first attempt runs on the device the logits
        live on; if that raises ``RuntimeError`` the call is repeated with
        ``device=None``.

        Args:
            logits: tensor whose first axis is the channel axis.
            filename: only used in the warning emitted when the first
                attempt fails.

        Returns:
            The two values returned by ``compute_masks``.
        """
        device = logits.device
        logits = logits.float().cpu().numpy()
        cellprob, dp = logits[0], logits[1:]

        # Settings shared by the first attempt and the fallback.
        settings = {"niter": 200, "cellprob_threshold": 0.4, "flow_threshold": 0.4, "interp": True}

        try:
            pred_mask, p = compute_masks(dp, cellprob, device=device, **settings)
        except RuntimeError as e:
            logger.warning(f"compute_masks failed on GPU retrying on CPU {logits.shape} file {filename} {e}")
            pred_mask, p = compute_masks(dp, cellprob, device=None, **settings)

        return pred_mask, p
| |
|
| |
|
class LogitsToLabelsd(MapTransform):
    """Dictionary wrapper around ``LogitsToLabels``."""

    def __call__(self, data):
        """Replace each selected entry by its predicted mask.

        The second value returned by ``LogitsToLabels`` is stored under
        ``"<key>_centroids"``.  Returns a shallow copy of ``data``.
        """
        d = dict(data)
        to_labels = LogitsToLabels()
        for key in self.key_iterator(d):
            d[key], d[f"{key}_centroids"] = to_labels(d[key])
        return d
| |
|
| |
|
class SaveTiffExd(MapTransform):
    """Save label masks and, optionally, contour overlays and contour JSON.

    For every selected key the label map is written with ``tifffile`` to
    ``<basename>[_<output_postfix>]<output_ext>`` inside the output directory.
    A ``meta.json`` file (image size and contour count) is always written and
    ``contours.json`` is written on request.

    The following entries of the input dictionary override the defaults:

    * ``"output_dir"`` / ``"output_ext"``: replace the constructor values.
    * ``"overlayed_masks"``: when true, the contours of each instance are
      drawn on the original image and the result overwrites the saved file.
    * ``"output_contours"``: when true, all polygons go to ``contours.json``.

    Args:
        output_dir: default output directory.
        output_ext: default extension of the written file.
        output_postfix: suffix appended to the base name (omitted if empty).
        image_key: key of the source image whose meta data provides the file
            name and the spatial shape.
        *args: forwarded to ``MapTransform``.
        **kwargs: forwarded to ``MapTransform``.
    """

    def __init__(self, output_dir, output_ext=".png", output_postfix="seg", image_key="image", *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.output_dir = output_dir
        self.output_ext = output_ext
        self.output_postfix = output_postfix
        self.image_key = image_key

    def to_polygons(self, contours):
        """Convert OpenCV contours to lists of integer points, dropping those with fewer than 3 points."""
        return [np.squeeze(contour).astype(int).tolist() for contour in contours if len(contour) >= 3]

    @staticmethod
    def _as_2d(mask):
        """Strip leading singleton axes so that OpenCV receives a 2-D array."""
        while mask.ndim > 2 and mask.shape[0] == 1:
            mask = mask[0]
        return mask

    def __call__(self, data):
        """Write the masks stored under the selected keys; returns a shallow copy of ``data``."""
        d = dict(data)

        output_dir = d.get("output_dir", self.output_dir)
        output_ext = d.get("output_ext", self.output_ext)
        overlayed_masks = d.get("overlayed_masks", False)
        output_contours = d.get("output_contours", False)

        # Create the directory that is actually written to; it may have been
        # overridden through the input dictionary.
        os.makedirs(output_dir, exist_ok=True)

        img = d.get(self.image_key, None)
        filename = img.meta.get(ImageMetaKey.FILENAME_OR_OBJ) if img is not None else None
        image_size = img.meta.get(ImageMetaKey.SPATIAL_SHAPE) if img is not None else None
        basename = os.path.splitext(os.path.basename(filename))[0] if filename else "mask"
        logger.info(f"File: {filename}; Base: {basename}")

        postfix = f"_{self.output_postfix}" if self.output_postfix else ""

        for key in self.key_iterator(d):
            label = d[key]
            if isinstance(label, torch.Tensor):
                # The code below relies on NumPy-only methods such as astype.
                label = label.cpu().numpy()

            # The file name does not depend on the key, so with several keys
            # later ones overwrite earlier ones (same for the JSON files).
            output_filename = f"{basename}{postfix}{output_ext}"
            output_filepath = os.path.join(output_dir, output_filename)
            lm = label.max()
            logger.info(f"Mask Shape: {label.shape}; Instances: {lm}")

            # Narrowest unsigned integer type able to hold the largest id.
            if lm <= 255:
                label = label.astype(np.uint8)
            elif lm <= 65535:
                label = label.astype(np.uint16)
            else:
                label = label.astype(np.uint32)

            # NOTE(review): TIFF data is written even when output_ext is not
            # a TIFF extension - confirm this is intended.
            tifffile.imwrite(output_filepath, label)
            logger.info(f"Saving {output_filepath}")

            # Work on the in-memory mask: re-reading the saved file as 8-bit
            # grayscale would corrupt instance ids above 255.
            mask = self._as_2d(label)

            polygons = []
            if overlayed_masks:
                logger.info(f"Overlay Masks: Reading original Image: {filename}")
                image = cv2.imread(filename) if filename else None
                if image is None:
                    logger.warning(f"Overlay Masks: cannot read original image {filename}; overlay is not saved")

                # int() avoids wrap-around of small unsigned NumPy scalars;
                # "+ 1" makes the highest instance id part of the loop.
                for i in range(1, int(lm) + 1):
                    m = (mask == i).astype(np.uint8)
                    contours, _ = cv2.findContours(m, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
                    polygons.extend(self.to_polygons(contours))
                    if image is not None:
                        color = np.random.choice(range(256), size=3).tolist()
                        cv2.drawContours(image, contours, -1, color, 1)

                if image is not None:
                    cv2.imwrite(output_filepath, image)
                    logger.info(f"Overlay Masks: Saving {output_filepath}")
            else:
                # findContours treats every non-zero pixel as foreground, so
                # a plain binary mask is sufficient.  It also avoids a
                # division by zero for empty masks and the loss of low ids
                # that rescaling to 0..255 causes when there are many
                # instances.
                foreground = (mask > 0).astype(np.uint8)
                contours, _ = cv2.findContours(foreground, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
                polygons.extend(self.to_polygons(contours))

            meta_json = {"image_size": image_size, "contours": len(polygons)}
            with open(os.path.join(output_dir, "meta.json"), "w") as fp:
                json.dump(meta_json, fp, indent=2)

            if output_contours:
                logger.info(f"Total Polygons: {len(polygons)}")
                with open(os.path.join(output_dir, "contours.json"), "w") as fp:
                    json.dump({"count": len(polygons), "contours": polygons}, fp, indent=2)

        return d
| |
|
| |
|
| | |
class CellLoss:
    """Training loss combining an MSE term and a BCE-with-logits term."""

    def __call__(self, y_pred, y):
        """Return ``0.5 * MSE + BCE`` for a prediction/target pair.

        Channel 0 (axis 1) of both tensors goes into
        ``binary_cross_entropy_with_logits``; the remaining channels go into
        ``mse_loss`` with the target multiplied by 5.

        Args:
            y_pred: predicted tensor, channels on axis 1.
            y: target tensor with the same layout.
                NOTE(review): presumably the tensor produced by
                ``LabelsToFlows`` (foreground mask followed by flows) -
                confirm against the training code.

        Returns:
            Scalar loss tensor.
        """
        flow_term = F.mse_loss(y_pred[:, 1:], 5 * y[:, 1:])
        mask_term = F.binary_cross_entropy_with_logits(y_pred[:, [0]], y[:, [0]])
        return 0.5 * flow_term + mask_term
| |
|
| |
|
| | |
class CellAcc:
    """Instance-level accuracy ``tp / (tp + fp + fn)`` at an IoU threshold of 0.5."""

    def __call__(self, mask_pred, mask_true):
        """Compare a predicted instance mask with the reference mask.

        Torch tensors are moved to the CPU and converted to NumPy first.

        Args:
            mask_pred: predicted instance labels (tensor or array).
            mask_true: reference instance labels (tensor or array).

        Returns:
            ``tp / (tp + fp + fn)`` where ``tp`` comes from cellpose's
            ``_true_positive`` and the largest label value of each mask is
            used as its instance count.
            NOTE(review): this presumes labels are numbered consecutively
            from 1, and the denominator is zero when both masks are empty -
            confirm how callers handle that.
        """
        if isinstance(mask_true, torch.Tensor):
            mask_true = mask_true.cpu().numpy()
        if isinstance(mask_pred, torch.Tensor):
            mask_pred = mask_pred.cpu().numpy()

        # Row 0 / column 0 are dropped before matching instances.
        pairwise_iou = _intersection_over_union(mask_true, mask_pred)[1:, 1:]
        tp = _true_positive(pairwise_iou, th=0.5)

        predicted_count = np.max(mask_pred)
        reference_count = np.max(mask_true)
        fp = predicted_count - tp
        fn = reference_count - tp

        return tp / (tp + fp + fn)
| |
|