Commit 47246dd0 authored by tim.scherr's avatar tim.scherr
Browse files

merging postprocessing added

parent 56d38703
......@@ -125,6 +125,7 @@ The best model (OP_CSB measure for GT & GT+ST, SEG measure calculated on ST for
### Parameters
- <tt>--apply_clahe</tt> / <tt>-acl</tt>: CLAHE pre-processing.
- <tt>--apply_merging</tt> / <tt>-am</tt>: Merging post-processing (only 2D, can resolve oversegmentation but may lead to undersegmentation)
- <tt>--artifact_correction</tt> / <tt>-ac</tt>: Motion-based artifact correction post-processing (only for 2D and dense data).
- <tt>--batch_size</tt> / <tt>-bs</tt>: batch size (**8**).
- <tt>--fuse_z_seeds</tt> / <tt>-fzs</tt>: Fuse seeds in axial direction (only for 3D).
......@@ -166,6 +167,7 @@ The results can be found in *./challenge_datasets/cell_type*.
### Parameters
- <tt>--apply_clahe</tt> / <tt>-acl</tt>: CLAHE pre-processing.
- <tt>--apply_merging</tt> / <tt>-am</tt>: Merging post-processing (only 2D, can resolve oversegmentation but may lead to undersegmentation)
- <tt>--artifact_correction</tt> / <tt>-ac</tt> : Motion-based artifact correction post-processing (only for 2D and dense data).
- <tt>--batch_size</tt> / <tt>-bs</tt>: batch size (**8**).
- <tt>--fuse_z_seeds</tt> / <tt>-fzs</tt>: Fuse seeds in axial direction (only for 3D).
......
......@@ -20,7 +20,7 @@ class EvalArgs(object):
"""
def __init__(self, th_cell, th_seed, n_splitting, apply_clahe, scale, cell_type, save_raw_pred,
artifact_correction, fuse_z_seeds):
artifact_correction, apply_merging, fuse_z_seeds):
"""
:param th_cell: Mask / cell size threshold.
......@@ -51,6 +51,7 @@ class EvalArgs(object):
self.save_raw_pred = save_raw_pred
self.artifact_correction = artifact_correction
self.fuse_z_seeds = fuse_z_seeds
self.apply_merging = apply_merging
def main():
......@@ -61,6 +62,7 @@ def main():
# Get arguments
parser = argparse.ArgumentParser(description='KIT-Sch-GE 2021 Cell Segmentation - Evaluation')
parser.add_argument('--apply_clahe', '-acl', default=False, action='store_true', help='CLAHE pre-processing')
parser.add_argument('--apply_merging', '-am', default=False, action='store_true', help='Merging post-processing')
parser.add_argument('--artifact_correction', '-ac', default=False, action='store_true', help='Artifact correction')
parser.add_argument('--batch_size', '-bs', default=8, type=int, help='Batch size')
parser.add_argument('--cell_type', '-ct', nargs='+', required=True, help='Cell type(s)')
......@@ -173,6 +175,7 @@ def main():
apply_clahe=args.apply_clahe, scale=scale_factor, cell_type=ct,
save_raw_pred=args.save_raw_pred,
artifact_correction=args.artifact_correction,
apply_merging=args.apply_merging,
fuse_z_seeds=args.fuse_z_seeds)
if '2D' in ct:
......
......@@ -19,6 +19,7 @@ def main():
# Get arguments
parser = argparse.ArgumentParser(description='KIT-Sch-GE 2021 Cell Segmentation - Inference')
parser.add_argument('--apply_clahe', '-acl', default=False, action='store_true', help='CLAHE pre-processing')
parser.add_argument('--apply_merging', '-am', default=False, action='store_true', help='Merging post-processing')
parser.add_argument('--artifact_correction', '-ac', default=False, action='store_true', help='Artifact correction')
parser.add_argument('--batch_size', '-bs', default=8, type=int, help='Batch size')
parser.add_argument('--cell_type', '-ct', nargs='+', required=True, help='Cell type(s) to predict')
......
import cv2
import numpy as np
from scipy.ndimage import gaussian_filter
from scipy.ndimage import gaussian_filter, binary_dilation
from skimage.segmentation import watershed
from skimage import measure
from skimage.feature import peak_local_max
from skimage.feature import peak_local_max, canny
from skimage.morphology import binary_closing
from segmentation.utils.utils import get_nucleus_ids
......@@ -145,6 +146,23 @@ def distance_postprocessing(border_prediction, cell_prediction, args, input_3d=F
# Marker-based watershed
prediction_instance = watershed(image=-cell_prediction, markers=seeds, mask=mask, watershed_line=False)
if args.apply_merging and np.max(prediction_instance) < 255:
# Get borders between touching cells
label_bin = prediction_instance > 0
pred_boundaries = cv2.Canny(prediction_instance.astype(np.uint8), 1, 1) > 0
pred_borders = cv2.Canny(label_bin.astype(np.uint8), 1, 1) > 0
pred_borders = pred_boundaries ^ pred_borders
pred_borders = measure.label(pred_borders)
for border_id in get_nucleus_ids(pred_borders):
pred_border = (pred_borders == border_id)
if np.sum(border_prediction[pred_border]) / np.sum(pred_border) < 0.075: # very likely splitted due to shape
# Get ids to merge
pred_border_dilated = binary_dilation(pred_border, np.ones(shape=(3, 3), dtype=np.uint8))
merge_ids = get_nucleus_ids(prediction_instance[pred_border_dilated])
if len(merge_ids) == 2:
prediction_instance[prediction_instance == merge_ids[1]] = merge_ids[0]
prediction_instance = measure.label(prediction_instance)
# Iterative splitting of cells detected as (probably) merged
if apply_splitting:
props = measure.regionprops(prediction_instance)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment