Commit 39c2860e authored by tim.scherr's avatar tim.scherr
Browse files

Fluo-N2DH-SIM+ & Fluo-N3DH-SIM+ update

parent 8166e9e6
......@@ -10,7 +10,7 @@ from pathlib import Path
from segmentation.inference.inference import inference_2d_ctc, inference_3d_ctc
from segmentation.training.cell_segmentation_dataset import CellSegDataset
from segmentation.training.autoencoder_dataset import AutoEncoderDataset
from segmentation.training.create_training_sets import create_ctc_training_sets
from segmentation.training.create_training_sets import create_ctc_training_sets, create_sim_training_sets
from segmentation.training.mytransforms import augmentors
from segmentation.training.training import train, train_auto
from segmentation.utils import utils, unets
......@@ -122,6 +122,9 @@ def main():
path_train_sets=path_train_data,
cell_types=paths['cell_types'])
# create_sim_training_sets(path_data=path_datasets,
# path_train_sets=path_train_data)
for cell_type in cell_types:
for architecture in settings['methods']:
......@@ -133,6 +136,8 @@ def main():
model_name = '{}_{}_{}'.format(cell_type, args.mode, architecture[2])
if args.mode == 'GT' and architecture[-1]: # auto-encoder pre-training only for GT
model_name += '-auto'
if cell_type in ['Fluo-N3DH-SIM+', 'Fluo-N2DH-SIM+']: # not needed for simulated data
continue
num_trained_models = len(list(path_models.glob('{}_model*.pth'.format(model_name))))
if "all" in args.mode or "ST" in args.mode:
iterations = settings['iterations'] - num_trained_models
......
......@@ -623,12 +623,9 @@ def create_ctc_training_sets(path_data, path_train_sets, cell_types):
frame = seg_gt_id.name.split('man_seg')[-1]
seg_gt = tiff.imread(str(seg_gt_id))
img = tiff.imread(str(seg_gt_id.parents[2] / train_set / "t{}".format(frame)))
tra_gt = tiff.imread(str(seg_gt_id.parents[1] / 'TRA' / "man_track{}".format(frame)))
img, seg_gt, tra_gt = foi_correction_train(cell_type, img, seg_gt, tra_gt)
if scale != 1:
img, seg_gt, tra_gt = downscale(img=img, seg_gt=seg_gt, scale=scale, tra_gt=tra_gt)
img, seg_gt, tra_gt = downscale(img=img, seg_gt=seg_gt, scale=scale)
# min-max normalize image to 0 - 65535
img = 65535 * (img.astype(np.float32) - img.min()) / (img.max() - img.min())
......@@ -1247,3 +1244,174 @@ def create_ctc_training_sets(path_data, path_train_sets, cell_types):
json.dump(train_val_ids, outfile, ensure_ascii=False, indent=2)
return None
def create_sim_training_sets(path_data, path_train_sets):
""" Create training sets for the simulated Cell Tracking Challenge data sets Fluo-N2DH-SIM+ & Fluo-N3DH-SIM+.
(needs some revision sometime to make the code readable ...)
:param path_data: Path to the directory containing the Cell Tracking Challenge training sets.
:type path_data: Path
:param path_train_sets: Path to save the training sets into.
:type path_train_sets: Path
:return: None
"""
crop_size = 320
cell_types = ['Fluo-N2DH-SIM+', 'Fluo-N3DH-SIM+']
modes = ['GT']
for cell_type in cell_types:
for mode in modes:
# Check if data set already exists
if len(list((path_train_sets / "{}_{}".format(cell_type, mode) / 'train').glob('*.tif'))) > 0:
continue
if "all" not in mode:
print(' ... {}: {} set ...'.format(cell_type, mode))
make_train_dirs(path=path_train_sets, cell_type=cell_type, mode=mode)
if cell_type == 'Fluo-N3DH-SIM+':
shutil.copytree(str(path_train_sets / 'Fluo-N2DH-SIM+_GT' / 'train'),
str(path_train_sets / 'Fluo-N3DH-SIM+_GT' / 'train'),
dirs_exist_ok=True)
shutil.copytree(str(path_train_sets / 'Fluo-N2DH-SIM+_GT' / 'val'),
str(path_train_sets / 'Fluo-N3DH-SIM+_GT' / 'val'),
dirs_exist_ok=True)
if mode == 'GT': # Simulated data sets are always fully annotated
# If train / val split is available: use it, if not: random
train_val_ids = get_train_val_split(cell_type=cell_type, mode=mode)
# Get ids of segmentation ground truth
seg_gt_ids_01 = sorted((path_data / 'training_datasets' / cell_type / '01_GT' / 'SEG').glob('*.tif'))
seg_gt_ids_02 = sorted((path_data / 'training_datasets' / cell_type / '02_GT' / 'SEG').glob('*.tif'))
if '3D' in cell_type:
seg_gt_ids_01 = seg_gt_ids_01[::3]
seg_gt_ids_02 = seg_gt_ids_02[::2]
seg_gt_ids = seg_gt_ids_01 + seg_gt_ids_02
# Get some settings for train data generation
search_radius, min_area, max_mal, scale = get_gt_settings(gt_id_list=seg_gt_ids)
# go through files and load SEG and TRA GT
slice_idx = 0
for seg_gt_id in seg_gt_ids:
train_set = seg_gt_id.parents[1].stem.split('_')[0]
if len(seg_gt_id.stem.split('_')) > 2: # only slice annotated
frame = seg_gt_id.stem.split('_')[2] + '.tif'
slice_idx = int(seg_gt_id.stem.split('_')[3])
else:
frame = seg_gt_id.name.split('man_seg')[-1]
seg_gt = tiff.imread(str(seg_gt_id))
img = tiff.imread(str(seg_gt_id.parents[2] / train_set / "t{}".format(frame)))
if scale != 1:
img, seg_gt, _ = downscale(img=img, seg_gt=seg_gt, scale=scale)
# min-max normalize image to 0 - 65535
img = 65535 * (img.astype(np.float32) - img.min()) / (img.max() - img.min())
img = np.clip(img, 0, 65535).astype(np.uint16)
if len(seg_gt.shape) == 3:
for i in range(len(seg_gt)):
img_slice = img[i].copy()
mask = seg_gt[i].copy()
if np.max(mask) == 0: # empty slice
continue
else:
if slice_idx % 4 == 0: # do not create for each slice training data
slice_idx += 1
pass
else:
slice_idx += 1
continue
nucleus_ids = get_nucleus_ids(mask)
hlabel = np.zeros(shape=mask.shape, dtype=mask.dtype)
for nucleus_id in nucleus_ids:
hlabel += nucleus_id * binary_closing(mask == nucleus_id, np.ones((5, 5))).astype(
mask.dtype)
mask = hlabel
tr_gt_slice = mask.copy() # assumption: in 3D GT annotations all cells are annotated
generate_data(img=img_slice, mask=mask, tra_gt=tr_gt_slice, search_radius=search_radius,
max_mal=max_mal, crop_size=crop_size, cell_type=cell_type, mode=mode,
train_set=train_set, frame=frame, min_area=min_area, scale=scale,
path_train_sets=path_train_sets, slice_idx=i)
else:
if '3D' in cell_type:
img = img[slice_idx]
# Needed seed could be outside the slice --> maximum intensity projection
slice_min = np.maximum(slice_idx-2, 0)
slice_max = np.minimum(slice_idx+2, len(img)-1)
tra_gt = np.max(tra_gt[slice_min:slice_max], axis=0) # best bring seed size to min_area ...
nucleus_ids = get_nucleus_ids(seg_gt)
hlabel = np.zeros(shape=seg_gt.shape, dtype=seg_gt.dtype)
for nucleus_id in nucleus_ids:
hlabel += nucleus_id * binary_closing(seg_gt == nucleus_id, np.ones((5, 5))).astype(
seg_gt.dtype)
seg_gt = hlabel
tra_gt = seg_gt.copy()
generate_data(img=img, mask=seg_gt, tra_gt=tra_gt, search_radius=search_radius,
max_mal=max_mal, crop_size=crop_size, cell_type=cell_type, mode=mode,
train_set=train_set, frame=frame, min_area=min_area, scale=scale,
path_train_sets=path_train_sets)
train_data_info = {'scale': scale,
'max_mal': max_mal,
'min_area': min_area,
'search_radius': search_radius}
with open(path_train_sets / "{}_{}".format(cell_type, mode) / 'info.json', 'w', encoding='utf-8') as outfile:
json.dump(train_data_info, outfile, ensure_ascii=False, indent=2)
# train/val splits
img_ids = sorted((path_train_sets / "{}_{}".format(cell_type, mode) / 'A').glob('img*.tif'))
if len(img_ids) <= 30: # Use also "B" quality images when too few "A" quality images are available
img_ids_B = sorted((path_train_sets / "{}_{}".format(cell_type, mode) / 'B').glob('img*.tif'))
else:
img_ids_B = []
if not train_val_ids: # no split available
img_ids_stem = []
for idx in img_ids:
img_ids_stem.append(idx.stem.split('img_')[-1])
# Random 80%/20% split
shuffle(img_ids_stem)
train_ids = img_ids_stem[0:int(np.floor(0.8 * len(img_ids_stem)))]
val_ids = img_ids_stem[int(np.floor(0.8 * len(img_ids_stem))):]
# Add "B" quality only to train
for idx in img_ids_B:
train_ids.append(idx.stem.split('img_')[-1])
train_val_ids = {'train': train_ids, 'val': val_ids}
with open(Path.cwd() / "segmentation" / "training" / "splits" / 'ids_{}_{}.json'.format(cell_type, mode),
'w', encoding='utf-8') as outfile:
json.dump(train_val_ids, outfile, ensure_ascii=False, indent=2)
for train_mode in ['train', 'val']:
for idx in train_val_ids[train_mode]:
source_path = path_train_sets / "{}_{}".format(cell_type, mode)
target_path = path_train_sets / "{}_{}".format(cell_type, mode) / train_mode
if (source_path / "A" / ("img_{}.tif".format(idx))).exists():
source_path = source_path / "A"
else:
source_path = source_path / "B"
shutil.copyfile(str(source_path / "img_{}.tif".format(idx)),
str(target_path / "img_{}.tif".format(idx)))
shutil.copyfile(str(source_path / "dist_cell_{}.tif".format(idx)),
str(target_path / "dist_cell_{}.tif".format(idx)))
shutil.copyfile(str(source_path / "dist_neighbor_{}.tif".format(idx)),
str(target_path / "dist_neighbor_{}.tif".format(idx)))
shutil.copyfile(str(source_path / "mask_{}.tif".format(idx)),
str(target_path / "mask_{}.tif".format(idx)))
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
eval "$(conda shell.bash hook)"
conda activate /srv/scherr/virtual_environments/kit-sch-ge-2021-cell_segmentation_ve
python ./cell_segmentation.py --train --cell_type "Fluo-N3DH-SIM+" --mode "GT"
python ./cell_segmentation.py --evaluate --cell_type "Fluo-N3DH-SIM+" --mode "GT" --save_raw_pred --batch_size 8 --fuse_z_seeds --n_splitting 100
conda deactivate
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment