Commit 8166e9e6 authored by tim.scherr

initial commit
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# pycharm project file
.idea
# data files
*.csv
*.mpeg
*__pycache__
*.idea
*.pdf
*.pth
# image data
*.tif
*.tiff
*.png
Copyright 2021 Tim Scherr <tim.scherr@kit.edu>
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# KIT-Sch-GE 2021 Segmentation
Segmentation method used for our submission to the 6th edition of the [ISBI Cell Tracking Challenge](http://celltrackingchallenge.net/) 2021 (Team KIT-Sch-GE).
## Prerequisites
* [Anaconda Distribution](https://www.anaconda.com/products/individual)
* A CUDA-capable GPU
* Minimum / recommended RAM: 16 GiB / 32 GiB
* Minimum / recommended VRAM: 12 GiB / 24 GiB
## Installation
Clone the repository:
```
git clone https://git.scc.kit.edu/KIT-Sch-GE/2021_segmentation.git
```
Open the Anaconda Prompt (Windows) or the Terminal (Linux), go to the repository and create a new virtual environment:
```
cd path_to_the_cloned_repository
conda env create -f requirements.yml
```
Activate the virtual environment *kit_sch-ge-2021_cell_segmentation_ve*:
```
conda activate kit_sch-ge-2021_cell_segmentation_ve
```
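Optionally, verify that PyTorch can access the GPU from within the activated environment:
```
import torch
print(torch.__version__, torch.cuda.is_available())  # quick sanity check, should print True
```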
## Cell Tracking Challenge 2021
This section describes how to reproduce the segmentation results of our Cell Tracking Challenge submission. To reproduce the exact submission results, download our trained models from the Cell Tracking Challenge website when available (and move them to *cell_tracking_challenge/kit-sch-ge_2021/SW/models*, see also the next step).
### Data
Download the Cell Tracking Challenge training and challenge data sets. Create a folder *cell_tracking_challenge*. Unzip the training data sets into *cell_tracking_challenge/training_datasets* and the challenge data sets into *cell_tracking_challenge/challenge_datasets*. Download and unzip the [evaluation software](http://public.celltrackingchallenge.net/software/EvaluationSoftware.zip). Set the corresponding paths in *paths.json*.
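For illustration, the expected folder layout and the keys read from *paths.json* can be checked with a few lines of Python (the `...` placeholders in *paths.json* have to be replaced with absolute paths on your system):
```
import json
from pathlib import Path

# Expected folder layout for the downloaded data
base = Path('cell_tracking_challenge')
(base / 'training_datasets').mkdir(parents=True, exist_ok=True)
(base / 'challenge_datasets').mkdir(parents=True, exist_ok=True)

# paths.json provides the three paths used throughout the code
with open('paths.json') as f:
    paths = json.load(f)
print(paths['path_data'], paths['path_ctc_metric'], paths['path_results'])
```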
### Training
After downloading the required Cell Tracking Challenge data, new models can be trained with:
```
python cell_segmentation.py --train --cell_type "cell_type" --mode "mode"
```
The required training data are created automatically using the train/val splits in *2021_segmentation/segmentation/training/splits* (this takes some time). To use new random splits, delete all JSON files in the corresponding folder (and the training sets, if already created).
The batch size and the number of models trained per *cell_type* and *mode* (GT, ST, GT+ST, allGT, allST, allGT+allST, depending on which label type should be used) can be adjusted in *cell_segmentation_train_settings.json*. With the default settings, one model is trained with the Adam optimizer and one with the [Ranger](https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer) optimizer. For the mode "GT", two models are trained per optimizer, plus two Ranger models with an autoencoder pre-training of the encoder.
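Each entry in the `methods` list of *cell_segmentation_train_settings.json* bundles one training configuration: the architecture specification consumed by *build_unet* (U-Net type, pooling method, activation function, normalization, filters), the label type, the optimizer, the loss, and a pre-training flag (interpreting the flag as autoencoder pre-training follows from the GT-mode description above). A minimal sketch to inspect the entries:
```
import json

with open('cell_segmentation_train_settings.json') as f:
    settings = json.load(f)

for architecture, label_type, optimizer, loss, pre_train in settings['methods']:
    unet_type, pool_method, act_fun, normalization, filters = architecture
    print(unet_type, pool_method, act_fun, normalization, filters,
          label_type, optimizer, loss, pre_train)
```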
*train_eval.sh* is a bash script for the training and evaluation of our whole submission (takes some time!).
### Evaluation
Trained models can be evaluated on the training datasets with:
```
python cell_segmentation.py --evaluate --cell_type "cell_type" --mode "mode"
```
Some raw predictions can be saved with *--save_raw_pred*. The batch size can be set with *--batch_size $int*. For some cell types, an artifact correction (*--artifact_correction*) or the fusion of seeds in *z*-direction (*--fuse_z_seeds*) can be helpful. The mask and marker thresholds to be evaluated can be found in *cell_segmentation.py*. For the modes ST and allST, the SEG score calculated on the provided STs is used to find the best model; for the other modes, the OP_CSB calculated on the provided GT data is used.
The best models are copied to *cell_tracking_challenge/kit-sch-ge_2021/SW/models*. The corresponding JSON files contain the best thresholds and the applied scaling factor (along with some other information). The results of the best model are copied to *cell_tracking_challenge/training_datasets/cell_type/Kit-Sch-GE_2021/mode/csb/*. All other results can be found in the specified *result_path*.
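The selection can be summarized by the following sketch (hypothetical result records and helper name; the actual logic lives in *cell_segmentation.py*):
```
def select_best_model(results, mode):
    """ results: list of dicts with the keys 'model', 'th_cell', 'th_seed', 'SEG' and 'OP_CSB'. """
    metric = 'SEG' if mode in ('ST', 'allST') else 'OP_CSB'
    return max(results, key=lambda result: result[metric])
```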
### Inference
For inference, use the copied best models and the corresponding parameters:
```
python cell_segmentation.py --inference --cell_type "cell_type" --mode "mode" --save_raw_pred --batch_size $int --th_cell $float --th_seed $float (--artifact_correction --fuse_z_seeds --apply_clahe --scale $float --multi_gpu)
```
*inference.sh* is a bash script with the parameters we used for our submission (it also requires our trained models).
## Publication ##
T. Scherr, K. Löffler, M. Böhland, and R. Mikut (2020). Cell Segmentation and Tracking using CNN-Based Distance Predictions and a Graph-Based Matching Strategy. PLoS ONE 15(12). DOI: [10.1371/journal.pone.0243219](https://doi.org/10.1371/journal.pone.0243219).
## License ##
This project is licensed under the MIT License - see the [LICENSE.md](LICENSE.md) file for details.
{
"methods":
[
[["DU", "conv", "relu", "bn", [64, 1024]], "distance", "adam", "smooth_l1", null],
[["DU", "conv", "mish", "bn", [64, 1024]], "distance", "ranger", "smooth_l1", null],
[["DU", "conv", "mish", "bn", [64, 1024]], "distance", "ranger", "smooth_l1", true]
],
"batch_size": 8,
"batch_size_auto": 2,
"iterations": 1,
"iterations_GT_single_celltype": 2
}
eval "$(conda shell.bash hook)"
conda activate kit_sch-ge-2021_cell_segmentation_ve
python ./cell_segmentation.py --inference --cell_type "BF-C2DL-HSC" --mode "GT" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.35 --artifact_correction
python ./cell_segmentation.py --inference --cell_type "BF-C2DL-MuSC" --mode "GT" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.35 --artifact_correction
python ./cell_segmentation.py --inference --cell_type "DIC-C2DH-HeLa" --mode "GT" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "Fluo-C2DL-MSC" --mode "GT" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.35 --scale 0.5
python ./cell_segmentation.py --inference --cell_type "Fluo-C3DH-A549" --mode "GT" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.35 --fuse_z_seeds
python ./cell_segmentation.py --inference --cell_type "Fluo-C3DH-H157" --mode "GT" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.35 --scale 0.6 --fuse_z_seeds
python ./cell_segmentation.py --inference --cell_type "Fluo-C3DL-MDA231" --mode "GT" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.45 --fuse_z_seeds
python ./cell_segmentation.py --inference --cell_type "Fluo-N2DH-GOWT1" --mode "GT" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "Fluo-N2DL-HeLa" --mode "GT" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "Fluo-N3DH-CE" --mode "GT" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "Fluo-N3DH-CHO" --mode "GT" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.45 --fuse_z_seeds
python ./cell_segmentation.py --inference --cell_type "PhC-C2DH-U373" --mode "GT" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.45
python ./cell_segmentation.py --inference --cell_type "PhC-C2DL-PSC" --mode "GT" --save_raw_pred --batch_size 16 --th_cell 0.09 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "BF-C2DL-HSC" --mode "allGT" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "BF-C2DL-MuSC" --mode "allGT" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "DIC-C2DH-HeLa" --mode "allGT" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "Fluo-C2DL-MSC" --mode "allGT" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "Fluo-C3DH-A549" --mode "allGT" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "Fluo-C3DH-H157" --mode "allGT" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "Fluo-C3DL-MDA231" --mode "allGT" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "Fluo-N2DH-GOWT1" --mode "allGT" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "Fluo-N2DL-HeLa" --mode "allGT" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "Fluo-N3DH-CE" --mode "allGT" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "Fluo-N3DH-CHO" --mode "allGT" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "PhC-C2DH-U373" --mode "allGT" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "PhC-C2DL-PSC" --mode "allGT" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "BF-C2DL-HSC" --mode "ST" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.35 --artifact_correction
python ./cell_segmentation.py --inference --cell_type "BF-C2DL-MuSC" --mode "ST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.35 --artifact_correction
python ./cell_segmentation.py --inference --cell_type "DIC-C2DH-HeLa" --mode "ST" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "Fluo-C2DL-MSC" --mode "ST" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "Fluo-C3DH-A549" --mode "ST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.35 --fuse_z_seeds
python ./cell_segmentation.py --inference --cell_type "Fluo-C3DH-H157" --mode "ST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.35 --fuse_z_seeds
python ./cell_segmentation.py --inference --cell_type "Fluo-C3DL-MDA231" --mode "ST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.45 --fuse_z_seeds
python ./cell_segmentation.py --inference --cell_type "Fluo-N2DH-GOWT1" --mode "ST" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.45
python ./cell_segmentation.py --inference --cell_type "Fluo-N2DL-HeLa" --mode "ST" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "Fluo-N3DH-CE" --mode "ST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.45
python ./cell_segmentation.py --inference --cell_type "Fluo-N3DH-CHO" --mode "ST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.45 --fuse_z_seeds
python ./cell_segmentation.py --inference --cell_type "PhC-C2DH-U373" --mode "ST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "PhC-C2DL-PSC" --mode "ST" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "BF-C2DL-HSC" --mode "allST" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "BF-C2DL-MuSC" --mode "allST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "DIC-C2DH-HeLa" --mode "allST" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "Fluo-C2DL-MSC" --mode "allST" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "Fluo-C3DH-A549" --mode "allST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "Fluo-C3DH-H157" --mode "allST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "Fluo-C3DL-MDA231" --mode "allST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "Fluo-N2DH-GOWT1" --mode "allST" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "Fluo-N2DL-HeLa" --mode "allST" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "Fluo-N3DH-CE" --mode "allST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "Fluo-N3DH-CHO" --mode "allST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "PhC-C2DH-U373" --mode "allST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "PhC-C2DL-PSC" --mode "allST" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "BF-C2DL-HSC" --mode "GT+ST" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.35 --artifact_correction
python ./cell_segmentation.py --inference --cell_type "BF-C2DL-MuSC" --mode "GT+ST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.35 --artifact_correction
python ./cell_segmentation.py --inference --cell_type "DIC-C2DH-HeLa" --mode "GT+ST" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "Fluo-C2DL-MSC" --mode "GT+ST" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "Fluo-C3DH-A549" --mode "GT+ST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.35 --fuse_z_seeds
python ./cell_segmentation.py --inference --cell_type "Fluo-C3DH-H157" --mode "GT+ST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.45 --fuse_z_seeds
python ./cell_segmentation.py --inference --cell_type "Fluo-C3DL-MDA231" --mode "GT+ST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.45 --fuse_z_seeds
python ./cell_segmentation.py --inference --cell_type "Fluo-N2DH-GOWT1" --mode "GT+ST" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.45
python ./cell_segmentation.py --inference --cell_type "Fluo-N2DL-HeLa" --mode "GT+ST" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "Fluo-N3DH-CE" --mode "GT+ST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.45
python ./cell_segmentation.py --inference --cell_type "Fluo-N3DH-CHO" --mode "GT+ST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.45 --fuse_z_seeds
python ./cell_segmentation.py --inference --cell_type "PhC-C2DH-U373" --mode "GT+ST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.35
python ./cell_segmentation.py --inference --cell_type "PhC-C2DL-PSC" --mode "GT+ST" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.45
python ./cell_segmentation.py --inference --cell_type "BF-C2DL-HSC" --mode "allGT+allST" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.45
python ./cell_segmentation.py --inference --cell_type "BF-C2DL-MuSC" --mode "allGT+allST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.45
python ./cell_segmentation.py --inference --cell_type "DIC-C2DH-HeLa" --mode "allGT+allST" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.45
python ./cell_segmentation.py --inference --cell_type "Fluo-C2DL-MSC" --mode "allGT+allST" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.45
python ./cell_segmentation.py --inference --cell_type "Fluo-C3DH-A549" --mode "allGT+allST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.45
python ./cell_segmentation.py --inference --cell_type "Fluo-C3DH-H157" --mode "allGT+allST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.45
python ./cell_segmentation.py --inference --cell_type "Fluo-C3DL-MDA231" --mode "allGT+allST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.45
python ./cell_segmentation.py --inference --cell_type "Fluo-N2DH-GOWT1" --mode "allGT+allST" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.45
python ./cell_segmentation.py --inference --cell_type "Fluo-N2DL-HeLa" --mode "allGT+allST" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.45
python ./cell_segmentation.py --inference --cell_type "Fluo-N3DH-CE" --mode "allGT+allST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.45
python ./cell_segmentation.py --inference --cell_type "Fluo-N3DH-CHO" --mode "allGT+allST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.45
python ./cell_segmentation.py --inference --cell_type "PhC-C2DH-U373" --mode "allGT+allST" --save_raw_pred --batch_size 8 --th_cell 0.07 --th_seed 0.45
python ./cell_segmentation.py --inference --cell_type "PhC-C2DL-PSC" --mode "allGT+allST" --save_raw_pred --batch_size 16 --th_cell 0.07 --th_seed 0.45
conda deactivate
{
"cell_types":
[
"BF-C2DL-HSC",
"BF-C2DL-MuSC",
"DIC-C2DH-HeLa",
"Fluo-C2DL-MSC",
"Fluo-C3DH-A549",
"Fluo-C3DH-H157",
"Fluo-C3DL-MDA231",
"Fluo-N2DH-GOWT1",
"Fluo-N2DL-HeLa",
"Fluo-N3DH-CE",
"Fluo-N3DH-CHO",
"PhC-C2DH-U373",
"PhC-C2DL-PSC"
],
"path_ctc_metric": ".../EvaluationSoftware/",
"path_data": ".../cell_tracking_challenge/",
"path_results": ".../kit-sch-ge_2021_segmentation/"
}
name: kit_sch-ge-2021_cell_segmentation_ve
channels:
- pytorch
- conda-forge
- anaconda
dependencies:
- cudnn
- cudatoolkit=11.0
- imgaug
- ipython
- numpy
- opencv
- pandas
- pillow
- python=3.8
- pytorch=1.7
- scikit-image
- scipy
- imageio
- tifffile
- torchvision
- pip
- pip:
- imagecodecs
import numpy as np
import tifffile as tiff
import torch
from skimage.exposure import equalize_adapthist
from skimage.transform import rescale
from torch.utils.data import Dataset
from torchvision import transforms
from segmentation.utils.utils import zero_pad_model_input
class CTCDataSet(Dataset):
""" Pytorch data set for Cell Tracking Challenge data. """
def __init__(self, data_dir, transform=lambda x: x):
"""
:param data_dir: Directory with the Cell Tracking Challenge images to predict (e.g., t001.tif)
:param transform: Pre-processing transforms to apply to each sample.
"""
self.img_ids = sorted(data_dir.glob('t*.tif'))
self.transform = transform
def __len__(self):
return len(self.img_ids)
def __getitem__(self, idx):
img_id = self.img_ids[idx]
img = tiff.imread(str(img_id))
sample = {'image': img,
'id': img_id.stem}
sample = self.transform(sample)
return sample
def pre_processing_transforms(apply_clahe, scale_factor):
""" Get transforms for the CTC data set.
:param apply_clahe: Whether to apply CLAHE (contrast-limited adaptive histogram equalization).
:type apply_clahe: bool
:param scale_factor: Downscaling factor <= 1.
:type scale_factor: float
:return: transforms
"""
data_transforms = transforms.Compose([ContrastEnhancement(apply_clahe),
Normalization(),
Scaling(scale_factor),
Padding(),
ToTensor()])
return data_transforms
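# Usage sketch (hypothetical paths): wrap a Cell Tracking Challenge sequence
# directory, e.g.
#   dataset = CTCDataSet(data_dir=Path('.../Fluo-N2DL-HeLa/01'),
#                        transform=pre_processing_transforms(apply_clahe=False, scale_factor=1))
# and feed it to a torch DataLoader as done in inference_2d_ctc / inference_3d_ctc.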
class ContrastEnhancement(object):
def __init__(self, apply_clahe):
self.apply_clahe = apply_clahe
def __call__(self, sample):
if self.apply_clahe:
img = sample['image']
img = equalize_adapthist(np.squeeze(img), clip_limit=0.01)
img = (65535 * img).astype(np.uint16)
sample['image'] = img
return sample
class Normalization(object):
def __call__(self, sample):
img = sample['image']
img = 2 * (img.astype(np.float32) - img.min()) / (img.max() - img.min()) - 1
sample['image'] = img
return sample
class Padding(object):
def __call__(self, sample):
img = sample['image']
img, pads = zero_pad_model_input(img=img, pad_val=np.min(img))
sample['image'] = img
sample['pads'] = pads
return sample
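# zero_pad_model_input is imported from segmentation.utils.utils and not shown
# in this commit. A rough sketch of the assumed behavior (the exact padding
# factor and return layout are assumptions): the spatial dimensions are padded
# at their leading edges to a size the U-Net can process, and the pad widths
# are returned so predictions can be cropped again, e.g.
#   def zero_pad_model_input(img, pad_val=0, factor=16):
#       pads = [(factor - s % factor) % factor for s in img.shape[-2:]]
#       pad_width = [(0, 0)] * (img.ndim - 2) + [(p, 0) for p in pads]
#       return np.pad(img, pad_width, mode='constant', constant_values=pad_val), pads
# (see inference_2d_ctc, which crops with prediction[..., pads[0]:, pads[1]:]).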
class Scaling(object):
def __init__(self, scale):
self.scale = scale
def __call__(self, sample):
img = sample['image']
sample['original_size'] = img.shape
if self.scale < 1:
if len(img.shape) == 3:
img = rescale(img, (1, self.scale, self.scale), order=2, preserve_range=True).astype(img.dtype)
else:
img = rescale(img, (self.scale, self.scale), order=2, preserve_range=True).astype(img.dtype)
sample['image'] = img
return sample
class ToTensor(object):
""" Convert image and label image to Torch tensors """
def __call__(self, sample):
img = sample['image']
if len(img.shape) == 2:
img = img[None, :, :]
img = torch.from_numpy(img).to(torch.float)
return img, sample['id'], sample['pads'], sample['original_size']
import gc
import json
import tifffile as tiff
import torch
from scipy.ndimage import binary_dilation
from skimage.measure import regionprops, label
from skimage.transform import resize
from segmentation.inference.ctc_dataset import CTCDataSet, pre_processing_transforms
from segmentation.inference.postprocessing import *
from segmentation.utils.unets import build_unet
def inference_2d_ctc(model, data_path, result_path, device, batchsize, args, num_gpus=None):
""" Inference function for 2D Cell Tracking Challenge data sets.
:param model: Path to the model to use for inference.
:type model: pathlib Path object.
:param data_path: Path to the directory containing the Cell Tracking Challenge data sets.
:type data_path: pathlib Path object
:param result_path: Path to the results directory.
:type result_path: pathlib Path object
:param device: Use (multiple) GPUs or CPU.
:type device: torch device
:param batchsize: Batch size.
:type batchsize: int
:param args: Arguments for post-processing.
:type args: object
:param num_gpus: Number of GPUs to use in GPU mode (enables larger batches)
:type num_gpus: int
:return: None
"""
# Load model json file to get architecture + filters
with open(model.parent / (model.stem + '.json')) as f:
model_settings = json.load(f)
# Build model
net = build_unet(unet_type=model_settings['architecture'][0],
act_fun=model_settings['architecture'][2],
pool_method=model_settings['architecture'][1],
normalization=model_settings['architecture'][3],
device=device,
num_gpus=num_gpus,
ch_in=1,
ch_out=1,
filters=model_settings['architecture'][4])
# Get number of GPUs to use and load weights
if not num_gpus:
num_gpus = torch.cuda.device_count()
if num_gpus > 1:
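# build_unet presumably wraps the net in nn.DataParallel for multi-GPU use,
# hence the .module access (assumption; build_unet is not shown in this commit)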
net.module.load_state_dict(torch.load(str(model), map_location=device))
else:
net.load_state_dict(torch.load(str(model), map_location=device))
net.eval()
torch.set_grad_enabled(False)
# Get images to predict
ctc_dataset = CTCDataSet(data_dir=data_path,
transform=pre_processing_transforms(apply_clahe=args.apply_clahe,
scale_factor=args.scale))
dataloader = torch.utils.data.DataLoader(ctc_dataset,
batch_size=batchsize,
shuffle=False,
pin_memory=True,
num_workers=8)
# Predict images (iterate over images/files)
for sample in dataloader:
img_batch, ids_batch, pad_batch, img_size = sample
img_batch = img_batch.to(device)
if batchsize > 1: # all images in a batch have same dimensions and pads
pad_batch = [pad_batch[i][0] for i in range(len(pad_batch))]
img_size = [img_size[i][0] for i in range(len(img_size))]
# Prediction
prediction_border_batch, prediction_cell_batch = net(img_batch)
# Get rid of pads
prediction_cell_batch = prediction_cell_batch[:, 0, pad_batch[0]:, pad_batch[1]:, None].cpu().numpy()
prediction_border_batch = prediction_border_batch[:, 0, pad_batch[0]:, pad_batch[1]:, None].cpu().numpy()
# Save only some raw predictions (not all, since the float32 predictions need a lot of memory)
save_ids = [0, len(ctc_dataset) // 8, len(ctc_dataset) // 4, 3 * len(ctc_dataset) // 8, len(ctc_dataset) // 2,
5 * len(ctc_dataset) // 8, 3 * len(ctc_dataset) // 4, 7 * len(ctc_dataset) // 8, len(ctc_dataset) - 1]
# Go through predicted batch and apply post-processing (not parallelized)
for h in range(len(prediction_border_batch)):
print(' ... processing {0} ...'.format(ids_batch[h]))
# Get actual file number:
file_num = int(ids_batch[h].split('t')[-1])
# Save only selected raw predictions to save memory
if file_num in save_ids and args.save_raw_pred:
save_raw_pred = True
else:
save_raw_pred = False
file_id = ids_batch[h].split('t')[-1] + '.tif'
if model_settings['label_type'] == 'distance':
prediction_instance, border = distance_postprocessing(border_prediction=prediction_border_batch[h],
cell_prediction=prediction_cell_batch[h],
args=args)
if args.scale < 1:
prediction_instance = resize(prediction_instance,
img_size,
order=0,
preserve_range=True,
anti_aliasing=False).astype(np.uint16)
prediction_instance = foi_correction(mask=prediction_instance, cell_type=args.cell_type)
tiff.imsave(str(result_path / ('mask' + file_id)), prediction_instance, compress=1)
if save_raw_pred:
tiff.imsave(str(result_path / ('cell' + file_id)), prediction_cell_batch[h, ..., 0].astype(np.float32), compress=1)
tiff.imsave(str(result_path / ('raw_border' + file_id)), prediction_border_batch[h, ..., 0].astype(np.float32), compress=1)
tiff.imsave(str(result_path / ('border' + file_id)), border.astype(np.float32), compress=1)
if args.artifact_correction:
# Artifact correction, based on the assumption that the cells are densely clustered and artifacts lie far away
roi = np.zeros_like(prediction_instance) > 0
prediction_instance_ids = sorted(result_path.glob('mask*'))
for prediction_instance_id in prediction_instance_ids:
roi = roi | (tiff.imread(str(prediction_instance_id)) > 0)
roi = binary_dilation(roi, np.ones(shape=(20, 20)))
roi = label(roi)
props = regionprops(roi)
# Keep only the largest region
largest_area, largest_area_id = 0, 0
for prop in props:
if prop.area > largest_area:
largest_area = prop.area
largest_area_id = prop.label
roi = (roi == largest_area_id)
for prediction_instance_id in prediction_instance_ids:
prediction_instance = tiff.imread(str(prediction_instance_id))
prediction_instance = prediction_instance * roi
tiff.imsave(str(prediction_instance_id), prediction_instance.astype(np.uint16), compress=1)
# Clear memory
del net
gc.collect()
return None
def inference_3d_ctc(model, data_path, result_path, device, batchsize, args, num_gpus=None):
""" Inference function for 2D Cell Tracking Challenge data sets.
:param model: Path to the model to use for inference.
:type model: pathlib Path object.
:param data_path: Path to the directory containing the Cell Tracking Challenge data sets.
:type data_path: pathlib Path object
:param result_path: Path to the results directory.
:type result_path: pathlib Path object
:param device: Use (multiple) GPUs or CPU.
:type device: torch device
:param batchsize: Batch size.
:type batchsize: int
:param args: Arguments for post-processing.
:type args: object
:param num_gpus: Number of GPUs to use in GPU mode (enables larger batches)
:type num_gpus: int
:return: None
"""
# Load model json file to get architecture + filters
with open(model.parent / (model.stem + '.json')) as f:
model_settings = json.load(f)
# Build model
net = build_unet(unet_type=model_settings['architecture'][0],
act_fun=model_settings['architecture'][2],
pool_method=model_settings['architecture'][1],
normalization=model_settings['architecture'][3],
device=device,
num_gpus=num_gpus,
ch_in=1,
ch_out=1,
filters=model_settings['architecture'][4])
# Get number of GPUs to use and load weights
if not num_gpus:
num_gpus = torch.cuda.device_count()
if num_gpus > 1:
net.module.load_state_dict(torch.load(str(model), map_location=device))
else:
net.load_state_dict(torch.load(str(model), map_location=device))
net.eval()
torch.set_grad_enabled(False)
# Get images to predict
ctc_dataset = CTCDataSet(data_dir=data_path,
transform=pre_processing_transforms(apply_clahe=args.apply_clahe,
scale_factor=args.scale))
dataloader = torch.utils.data.DataLoader(ctc_dataset,
batch_size=batchsize,
shuffle=False,