Commit cd580a7c authored by sp8646

initial commit

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# pycharm project file
.idea
# data files
*.csv
*.mpeg
*.lp
*.tif
*.tiff
*.out
*.spec
!run_tracking.spec
from pathlib import Path
def get_project_path():
    return Path(__file__).parent


def get_data_path():
    project_path = get_project_path()
    parent_dir = project_path.parent
    return parent_dir / 'data'


def get_results_path():
    project_path = get_project_path()
    parent_dir = project_path.parent
    return parent_dir / 'Results'


# get string path
def string_path(path_arg):
    if not isinstance(path_arg, str):
        if hasattr(path_arg, 'as_posix'):
            path_arg = path_arg.as_posix()
        else:
            raise TypeError('Cannot convert variable to string path')
    else:
        path_arg = path_arg.replace('\\', '/')
    return path_arg
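# Minimal usage sketch (illustrative, not part of the original file): string_path
# normalises both pathlib.Path objects and Windows-style strings to forward slashes.
#   string_path(Path('data') / 'img_01')  # -> 'data/img_01'
#   string_path('data\\img_01')           # -> 'data/img_01'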
image_formats = ('bmp', 'jpeg', 'tif', 'png', 'tiff')
# This file may be used to create an environment using:
# $ conda create --name <env> --file <this file>
# platform: linux-64
_libgcc_mutex=0.1=main
altgraph=0.17=pyhd3eb1b0_0
blas=1.0=mkl
blosc=1.21.0=h8c45485_0
brotli=1.0.9=he6710b0_2
brunsli=0.1=h2531618_0
bzip2=1.0.8=h7b6447c_0
ca-certificates=2021.1.19=h06a4308_0
certifi=2020.12.5=py38h06a4308_0
charls=2.1.0=he6710b0_2
cloudpickle=1.6.0=py_0
cvxopt=1.2.0=py38hfa32c7d_0
cycler=0.10.0=py38_0
cytoolz=0.11.0=py38h7b6447c_0
dask-core=2021.2.0=pyhd3eb1b0_0
decorator=4.4.2=pyhd3eb1b0_0
freetype=2.10.4=h5ab3b9f_0
giflib=5.1.4=h14c3975_1
glpk=4.65=h3ceedfd_2
gmp=6.2.1=h2531618_2
gsl=2.4=h14c3975_4
gurobi=9.1.1=py38_0
imagecodecs=2021.1.11=py38h581e88b_1
imageio=2.9.0=py_0
intel-openmp=2020.2=254
joblib=1.0.1=pyhd3eb1b0_0
jpeg=9b=h024ee3a_2
jxrlib=1.1=h7b6447c_2
kiwisolver=1.3.1=py38h2531618_0
lcms2=2.11=h396b838_0
lerc=2.2.1=h2531618_0
libaec=1.0.4=he6710b0_1
libdeflate=1.7=h27cfd23_5
libedit=3.1.20191231=h14c3975_1
libffi=3.2.1=hf484d3e_1007
libgcc-ng=9.1.0=hdf63c60_0
libgfortran-ng=7.3.0=hdf63c60_0
libpng=1.6.37=hbc83047_0
libstdcxx-ng=9.1.0=hdf63c60_0
libtiff=4.1.0=h2733197_1
libwebp=1.0.1=h8e7db2f_0
libzopfli=1.0.3=he6710b0_0
lz4-c=1.9.3=h2531618_0
macholib=1.14=pyhd3eb1b0_1
matplotlib-base=3.3.4=py38h62a2d02_0
metis=5.1.0=hf484d3e_4
mkl=2020.2=256
mkl-service=2.3.0=py38he904b0f_0
mkl_fft=1.2.1=py38h54f3939_0
mkl_random=1.1.1=py38h0573a6f_0
ncurses=6.2=he6710b0_1
networkx=2.5=py_0
numpy=1.19.2=py38h54aff64_0
numpy-base=1.19.2=py38hfa32c7d_0
olefile=0.46=py_0
openjpeg=2.3.0=h05c96fa_1
openssl=1.1.1j=h27cfd23_0
pandas=1.2.2=py38ha9443f7_0
pillow=8.1.0=py38he98fc37_0
pip=21.0.1=py38h06a4308_0
pycryptodome=3.10.1=py38h3dc18e1_0
pyinstaller=3.6=py38hbc83047_5
pyparsing=2.4.7=pyhd3eb1b0_0
python=3.8.0=h0371630_2
python-dateutil=2.8.1=pyhd3eb1b0_0
pytz=2021.1=pyhd3eb1b0_0
pywavelets=1.1.1=py38h7b6447c_2
pyyaml=5.4.1=py38h27cfd23_1
readline=7.0=h7b6447c_5
scikit-image=0.17.2=py38hdf5156a_0
scikit-learn=0.23.2=py38h0573a6f_0
scipy=1.6.1=py38h91f5cce_0
setuptools=52.0.0=py38h06a4308_0
six=1.15.0=py38h06a4308_0
snappy=1.1.8=he6710b0_0
sqlite=3.33.0=h62c20be_0
suitesparse=5.2.0=h9e4a6bb_0
tbb=2020.3=hfd86e86_0
threadpoolctl=2.1.0=pyh5ca1d4c_0
tifffile=2021.1.14=pyhd3eb1b0_1
tk=8.6.10=hbc83047_0
toolz=0.11.1=pyhd3eb1b0_0
tornado=6.1=py38h27cfd23_0
wheel=0.36.2=pyhd3eb1b0_0
xz=5.2.5=h7b6447c_0
yaml=0.2.5=h7b6447c_0
zfp=0.5.5=h2531618_4
zlib=1.2.11=h7b6447c_3
zstd=1.4.5=h9ceee32_0
from pathlib import Path
import numpy as np
from tifffile import imread
from tracker.export import ExportResults
from tracker.extract_data import get_img_files
from tracker.extract_data import get_indices_pandas
from tracker.tracking import TrackingConfig, MultiCellTracker
def run_graph2_0(img_path, segm_path, res_path, delta_t=3, default_roi_size=2):
    img_path = Path(img_path)
    segm_path = Path(segm_path)
    res_path = Path(res_path)

    img_files = get_img_files(img_path)
    segm_files = get_img_files(segm_path, 'mask')

    # set roi size (assume img shape z,x,y): use the median mask extent scaled by
    # default_roi_size; for 2D images with few objects fall back to a tenth of the image size
    dummy = np.squeeze(imread(segm_files[max(segm_files.keys())]))
    img_shape = dummy.shape

    masks = get_indices_pandas(imread(segm_files[max(segm_files.keys())]))
    m_shape = np.stack(masks.apply(lambda x: np.max(np.array(x), axis=-1) - np.min(np.array(x), axis=-1) + 1))

    if len(img_shape) == 2:
        if len(masks) > 10:
            m_size = np.median(np.stack(m_shape)).astype(int)
            roi_size = tuple([m_size * default_roi_size, m_size * default_roi_size])
        else:
            roi_size = tuple((np.array(dummy.shape) // 10).astype(int))
    else:
        roi_size = tuple((np.median(np.stack(m_shape), axis=0) * default_roi_size).astype(int))

    config = TrackingConfig(img_files, segm_files, roi_size, delta_t=delta_t, cut_off_distance=None)
    tracker = MultiCellTracker(config)
    tracks = tracker()

    exporter = ExportResults()
    exporter(tracks, res_path, tracker.img_shape, time_steps=sorted(img_files.keys()))


if __name__ == '__main__':
    from argparse import ArgumentParser

    PARSER = ArgumentParser(description='Tracking KIT-Sch-GE')
    PARSER.add_argument('--image_path', type=str)
    PARSER.add_argument('--delta_t', type=int, default=3)
    PARSER.add_argument('--default_roi_size', type=int, default=2)
    ARGS = PARSER.parse_args()

    SEGM_PATH = Path(ARGS.image_path).as_posix() + '_RES'
    run_graph2_0(ARGS.image_path, SEGM_PATH, SEGM_PATH, ARGS.delta_t, ARGS.default_roi_size)
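# Example invocation (hypothetical data layout; the script name is assumed from the
# run_tracking.spec entry in the .gitignore). Masks are read from, and results are
# written to, the sibling '<image_path>_RES' folder:
#   python run_tracking.py --image_path /data/challenge/01 --delta_t 3 --default_roi_size 2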
import torch
from torch import nn
from scipy.ndimage import gaussian_filter
import numpy as np
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class GaussianKernel(nn.Module):
    def __init__(self, sigma, n_dims, padding=False):
        super().__init__()
        self.padding = padding
        self.n_dims = n_dims
        self.n_sigma = 3
        # kernel width: at least 3, odd, and covering +- n_sigma * sigma
        width = max(3, np.ceil(self.n_sigma * sigma))
        if width % 2 == 0:
            width += 1
        width = int(width)
        if self.padding:
            self.pad_w = width // 2
        else:
            self.pad_w = 0
        # build the kernel by Gaussian-filtering a unit impulse
        f = np.zeros([width for _ in range(self.n_dims)])
        if self.n_dims == 3:
            f[len(f) // 2, len(f) // 2, len(f) // 2] = 1
            self.conv = torch.conv3d
        else:
            f[len(f) // 2, len(f) // 2] = 1
            self.conv = torch.conv2d
        self.kernel_weights = torch.tensor(gaussian_filter(f, sigma=sigma, truncate=self.n_sigma),
                                           device=device).float().unsqueeze(0).unsqueeze(0)

    def forward(self, x):
        # todo: change to reflect mode padding
        return self.conv(x, self.kernel_weights, padding=self.pad_w)
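# Minimal usage sketch (assumed N,C,H,W input; not from the original repo):
#   blur = GaussianKernel(sigma=2.0, n_dims=2, padding=True)
#   smoothed = blur(torch.rand(1, 1, 64, 64, device=device))  # spatial size preserved by padding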
class GaussianPyramid(nn.Module):
    def __init__(self, n_dims, ratio, n_levels):
        super().__init__()
        # clamp the downsampling ratio to a sensible range
        self.ratio = 0.75 if ratio > 0.98 or ratio < 0.4 else ratio
        self.n_levels = n_levels
        self.sigma_base = 1 / self.ratio - 1
        self.n = int(np.ceil(np.log(0.25) / np.log(self.ratio)))
        self.gaussian_pyramid = [GaussianKernel(self.sigma_base * (i + 1), n_dims, True) for i in range(self.n)]

    def forward(self, x):
        img_pyramid = []
        for i in range(self.n_levels):
            if (i + 1) <= self.n:
                o = self.gaussian_pyramid[i](x)
                rescale_size = (torch.tensor(o.shape[2:]) * self.ratio ** (i + 1)).int()
                rescale_size[rescale_size < 1] = 1
                o = nn.functional.interpolate(o, tuple(rescale_size))
            else:
                o = self.gaussian_pyramid[self.n - 1](img_pyramid[i - self.n + 1])
                # todo: pad to small axis
                rescale_size = (torch.tensor(o.shape[2:]) * self.ratio ** self.n).int()
                rescale_size[rescale_size < 1] = 1
                o = nn.functional.interpolate(o, tuple(rescale_size))
            img_pyramid.append(o)
        img_pyramid = img_pyramid[::-1]
        img_pyramid.append(x)
        return img_pyramid  # coarsest scale first, original resolution last
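# Minimal usage sketch (assumed shapes; not from the original repo):
#   pyramid = GaussianPyramid(n_dims=2, ratio=0.75, n_levels=5)
#   levels = pyramid(torch.rand(1, 1, 256, 256, device=device))
#   # levels[0] is the coarsest scale; levels[-1] is the unfiltered input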
class BronxOpticalFlow(nn.Module):
    """Coarse-to-fine variational optical flow in the spirit of Brox et al. (work in progress)."""
    def __init__(self, n_dims, ratio, n_levels, n_outer_iter, n_inner_iter=10, n_sor_iterations=10, alpha=0.3):
        super().__init__()
        self.n_dims = n_dims
        self.ratio = ratio
        self.n_levels = n_levels
        self.n_outer_iter = n_outer_iter
        self.n_inner_iter = n_inner_iter
        self.n_sor_iterations = n_sor_iterations
        self.alpha = alpha
        self.gaussian_pyramid = GaussianPyramid(self.n_dims, self.ratio, self.n_levels)

        # set up derivative and smoothing kernels;
        # diff is the fourth-order central difference f'(x) ~ (-f(x+2) + 8f(x+1) - 8f(x-1) + f(x-2)) / 12
        s = torch.tensor([0.02, 0.11, 0.74, 0.11, 0.02])
        diff = -1 / 12 * torch.tensor([-1, 8, 0, -8, 1])
        if self.n_dims == 3:
            spatial_weights = [diff.reshape(-1, 1, 1), diff.reshape(1, -1, 1), diff.reshape(1, 1, -1)]
        else:
            spatial_weights = [diff.reshape(-1, 1), diff.reshape(1, -1)]
        self.spatial_kernel_t = spatial_weights[0].unsqueeze(0).unsqueeze(0).to(device)
        self.spatial_kernel_x = spatial_weights[0].unsqueeze(0).unsqueeze(0).to(device)
        self.spatial_kernel_y = spatial_weights[1].unsqueeze(0).unsqueeze(0).to(device)
        if self.n_dims == 3:
            self.spatial_kernel_z = spatial_weights[2].unsqueeze(0).unsqueeze(0).to(device)
            self.smoothing_kernel = torch.tensor(s.reshape(-1, 1, 1) * s.reshape(1, -1, 1) * s.reshape(1, 1, -1),
                                                 device=device).float().unsqueeze(0).unsqueeze(0)
            self.conv = nn.functional.conv3d
        else:
            self.smoothing_kernel = torch.tensor(s.reshape(-1, 1) * s.reshape(1, -1),
                                                 device=device).float().unsqueeze(0).unsqueeze(0)
            self.conv = nn.functional.conv2d
            self.spatial_kernel_z = None
        self.smoothing_kernel /= self.smoothing_kernel.sum()

    def calc_derivative_2D(self, image_1, image_2):
        # pre-smooth both frames, then differentiate a weighted blend of them
        img_1 = self.conv(image_1, self.smoothing_kernel, padding=2)
        img_2 = self.conv(image_2, self.smoothing_kernel, padding=2)
        dt = img_2 - img_1
        fused = img_1 * 0.4 + img_2 * 0.6
        dx = self.conv(fused, self.spatial_kernel_x, padding=2)
        dy = self.conv(fused, self.spatial_kernel_y, padding=2)
        return dx, dy, dt
    def forward(self, x):
        pass

    def calc_derivatives_2D(self, img1, img2):
        # todo: compute dx, dy, dt, dxx, dyy, dxt, dyt
        pass

    def calc_flow_2d(self, image_1, image_2, u, v):
        pyramid_img_1 = self.gaussian_pyramid(image_1)
        pyramid_img_2 = self.gaussian_pyramid(image_2)
        u = torch.zeros_like(image_1, device=device)
        v = torch.zeros_like(image_1, device=device)
        for im1, im2 in zip(pyramid_img_1, pyramid_img_2):
            # todo: upscale u, v -> interpolate
            # re-init du, dv as leaf tensors for the inner optimisation
            du = torch.zeros_like(u, device=device).requires_grad_(True)
            dv = torch.zeros_like(v, device=device).requires_grad_(True)
            derivatives = self.calc_derivative_2D(im1, im2)
            # todo: psi' data term, psi' smoothness term, Euler-Lagrange equations
            # do n LBFGS steps (== inner loop);
            # note: torch.optim.LBFGS.step expects a closure that re-evaluates the loss
            optim = torch.optim.LBFGS(params=[du, dv])
            optim.zero_grad()
            for _ in range(self.n_inner_iter):
                loss = torch.nn.functional.l1_loss(..., 0)  # todo: placeholder, loss terms still missing
                loss.backward()
                optim.step()
            # update u, v
            du = du.detach()
            dv = dv.detach()
            u += du
            v += dv
@torch.jit.script
def psi(x, epsilon: float = 0.001):
    """Charbonnier penalty, a differentiable approximation of |x|."""
    return torch.sqrt(x ** 2 + epsilon ** 2)


@torch.jit.script
def psi_derivative(x, epsilon: float = 0.001):
    # exact derivative of psi: d/dx sqrt(x^2 + eps^2) = x * (x^2 + eps^2)^(-1/2)
    return torch.pow(x ** 2 + epsilon ** 2, -0.5) * x
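# Worked example (illustrative): at x = 0 the penalty equals epsilon and its
# derivative is 0; for |x| >> epsilon, psi(x) ~ |x| and psi_derivative(x) ~ sign(x),
# i.e. psi behaves like a smooth L1 loss with a differentiable minimum.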
"""Utilities to export tracking results to ctc format"""
import os
import numpy as np
import pandas as pd
from tifffile import imsave
from tracker.postprocessing import add_dummy_masks, untangle_tracks
from tracker.postprocessing import no_fn_correction, no_untangling
class ExportResults:
    def __init__(self, postprocessing_key=None):
        """
        Exports tracking results to ctc format.
        Args:
            postprocessing_key: optional string to switch off post-processing steps;
                if none is provided, both post-processing steps (untangling, FN correction) are applied.
                'nd': no untangling,
                'ns+l': no FN correction but keep link of fragmented track as predecessor-successor,
                'ns-l': no FN correction and no link,
                'nd_ns+l': no untangling and no FN correction,
                    but keep link of fragmented track as predecessor-successor,
                'nd_ns-l': no untangling, no FN correction and no link
        """
        self.img_file_name = 'mask'
        self.img_file_ending = '.tif'
        self.track_file_name = 'res_track.txt'
        self.time_steps = None
        self.postprocessing_key = postprocessing_key
    def __call__(self, tracks, export_dir, img_shape, time_steps):
        """
        Post-processes a tracking result and exports it to the ctc format of tracking masks and a lineage file.
        Args:
            tracks: a dict containing the trajectories
            export_dir: a path where to store the exported tracking results
            img_shape: a tuple providing the shape for the tracking masks
            time_steps: a list of time steps
        """
        if not os.path.exists(export_dir):
            os.makedirs(export_dir)
        self.time_steps = time_steps

        if self.postprocessing_key == 'nd':
            tracks = no_untangling(tracks)
            print('add dummy masks')
            tracks = add_dummy_masks(tracks, img_shape)
        elif self.postprocessing_key == 'ns+l':
            print('untangle')
            tracks = untangle_tracks(tracks)
            tracks = no_fn_correction(tracks, keep_link=True)
        elif self.postprocessing_key == 'ns-l':
            print('untangle')
            tracks = untangle_tracks(tracks)
            tracks = no_fn_correction(tracks, keep_link=False)
        elif self.postprocessing_key == 'nd_ns+l':
            tracks = no_untangling(tracks)
            tracks = no_fn_correction(tracks, keep_link=True)
        elif self.postprocessing_key == 'nd_ns-l':
            tracks = no_untangling(tracks)
            tracks = no_fn_correction(tracks, keep_link=False)
        else:
            print('untangle')
            tracks = untangle_tracks(tracks)
            print('add dummy masks')
            tracks = add_dummy_masks(tracks, img_shape)
        tracks = catch_tra_issues(tracks, time_steps)
        print('export masks')
        self.create_lineage_file(tracks, export_dir)
        self.create_segm_masks(tracks, export_dir, img_shape)
    def create_lineage_file(self, tracks, export_dir):
        """
        Creates the lineage file.
        Args:
            tracks: a dict containing the trajectories
            export_dir: path to the folder where the results shall be stored
        """
        track_info = {'track_id': [], 't_start': [], 't_end': [], 'predecessor_id': []}
        for t_id in sorted(tracks.keys()):
            track_data = tracks[t_id]
            track_info['track_id'].append(track_data.track_id)
            frame_ids = sorted(list(track_data.masks.keys()))
            track_info['t_start'].append(frame_ids[0])
            track_info['t_end'].append(frame_ids[-1])
            if isinstance(track_data.pred_track_id, list):
                if len(track_data.pred_track_id) > 0:
                    track_data.pred_track_id = track_data.pred_track_id[0]
                else:
                    track_data.pred_track_id = 0  # no predecessor
            track_info['predecessor_id'].append(track_data.pred_track_id)
        df = pd.DataFrame.from_dict(track_info)
        df.to_csv(os.path.join(export_dir, self.track_file_name),
                  columns=['track_id', 't_start', 't_end', 'predecessor_id'],
                  sep=' ', index=False, header=False)
    def create_segm_masks(self, all_tracks, export_dir, img_shape):
        """
        Creates for each time step a tracking image with masks
        corresponding to the segmented and tracked objects.
        Args:
            all_tracks: a dict containing the trajectories
            export_dir: a path where to store the exported tracking results
            img_shape: a tuple providing the shape for the tracking masks
        """
        tracks_in_frame = {}
        # create a dict entry for each time step, otherwise time steps could be missing -> no img exported
        for t_step in self.time_steps:
            if t_step not in tracks_in_frame:
                tracks_in_frame[t_step] = []
        for track_data in all_tracks.values():
            time_steps = sorted(list(track_data.masks.keys()))
            for t_step in time_steps:
                if t_step not in tracks_in_frame:
                    tracks_in_frame[t_step] = []
                tracks_in_frame[t_step].append(track_data.track_id)

        t_max = sorted(list(tracks_in_frame.keys()))[-1]
        z_fill = int(np.ceil(max(np.log10(max(1, t_max)), 3)))  # frame ids are either 3 or 4 digits long
        for time, track_ids in tracks_in_frame.items():
            tracking_mask = create_tracking_mask_image(all_tracks, time, track_ids, img_shape)
            file_name = self.img_file_name + str(time).zfill(z_fill) + self.img_file_ending
            tracking_mask = np.array(np.squeeze(tracking_mask), dtype=np.uint16)
            imsave(os.path.join(export_dir, file_name), tracking_mask, compress=1)
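# Minimal usage sketch (hypothetical paths and shapes; mirrors the call in run_graph2_0 above):
#   exporter = ExportResults()  # default: apply untangling and FN correction
#   exporter(tracks, '/data/challenge/01_RES', img_shape=(512, 512), time_steps=list(range(92)))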
def create_tracking_mask_image(all_tracks, time, track_ids, img_shape):
    """
    Constructs image containing tracking masks and resolves overlapping masks.
    Args:
        all_tracks: a dict containing the trajectories
        time: int indicating the time point
        track_ids: list of track ids at the selected time point
        img_shape: a tuple providing the image shape of the mask image
    Returns: an np.array with the tracking masks for a time point
    """
    all_masks = {}
    all_mask_center = []
    all_mask_ids = []
    tracking_mask = np.zeros((1, *img_shape), dtype=np.uint16)
    for t_id in track_ids:
        track = all_tracks[t_id]
        mask = track.masks[time]
        all_masks[t_id] = mask
        # pick a representative center pixel per mask that is not shared with another mask
        mask_median = np.median(mask, axis=-1)
        if not all_mask_center:
            all_mask_center.append(mask_median)
        elif not np.any(np.all(mask_median == all_mask_center, axis=-1)):
            all_mask_center.append(mask_median)
        else:
            dist = np.linalg.norm(np.array(mask) - mask_median.reshape(-1, 1), axis=0)
            sorted_ids = np.argsort(dist, axis=0)
            sorted_mask = np.array(mask)[:, sorted_ids]
            index_nearest_point = np.argmin([np.any(np.all(el == all_mask_center, axis=-1))
                                             for el in sorted_mask.transpose()])
            all_mask_center.append(sorted_mask[:, index_nearest_point])
        all_mask_ids.append(t_id)

        # due to interpolated masks: overlap with other masks possible -> reassign overlapping pixels
        colliding_pixels = np.array([np.any(img_plane[mask] > 0, axis=0) for img_plane in tracking_mask])
        if np.all(colliding_pixels > 0):
            # every existing plane collides -> add a new plane
            tracking_mask = np.vstack([tracking_mask, np.zeros((1, *img_shape), dtype=np.uint16)])
            img_plane = tracking_mask[-1]
        else:
            # write to the first plane without a collision; select the plane first and then
            # index with the mask, as otherwise the mask indices are interpreted as a matrix;
            # img_plane is a view on tracking_mask, so tracking_mask is edited in place
            img_plane = tracking_mask[np.argmax(colliding_pixels == 0)]
        img_plane[mask] = t_id

    if tracking_mask.shape[0] > 1:  # colliding pixels exist
        is_collision = np.sum(tracking_mask > 0, axis=0) > 1
        single_plane = tracking_mask.copy()
        single_plane[:, is_collision] = 0
        single_plane = np.sum(single_plane, axis=0)
        all_mask_ids = np.array(all_mask_ids)
        all_mask_center = np.array(all_mask_center)
        ind_pixel = list(zip(*np.where(is_collision)))
        pixel_masks = tracking_mask[:, is_collision].T
        for pixel_ind, masks_pixel in zip(ind_pixel, pixel_masks):
            # sort, as all_mask_ids is sorted as well -> otherwise swaps in m_ids are possible
            m_ids = sorted(masks_pixel[masks_pixel > 0])
            # assign each contested pixel to the track whose mask center is closest
            mask_centers = all_mask_center[np.isin(all_mask_ids, m_ids)]
            dist = np.sqrt(np.sum(np.square(mask_centers - np.array(pixel_ind).reshape(1, -1)), axis=-1))
            single_plane[pixel_ind] = m_ids[np.argmin(dist)]
        # collapse planes: each tracked object now has a single segmentation mask in the image
        tracking_mask = single_plane
    return tracking_mask
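# Illustrative behaviour (hypothetical numbers): if the interpolated masks of tracks 3 and 5
# overlap in a few pixels, the later-written mask lands on a fresh image plane first, and each
# contested pixel is then reassigned to whichever of the two mask centers lies closer to it.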
def catch_tra_issues(tracks, time_steps):
    """
    Adds for each empty tracking frame the tracking result of the temporally closest frame.
    Otherwise the CTC measure can yield an error.
    Args:
        tracks: a dict containing the tracking results
        time_steps: a list of time steps
    Returns: the modified tracks
    """
    tracks_in_frame = {}
    for track_data in tracks.values():
        track_timesteps = sorted(list(track_data.masks.keys()))
        for t_step in track_timesteps:
            if t_step not in tracks_in_frame:
                tracks_in_frame[t_step] = []
            tracks_in_frame[t_step].append(track_data.track_id)

    if sorted(time_steps) != sorted(list(tracks_in_frame.keys())):
        empty_timesteps = sorted(np.array(time_steps)[~np.isin(time_steps, list(tracks_in_frame.keys()))])
        filled_timesteps = np.array(sorted(list(tracks_in_frame.keys())))
        for empty_frame in empty_timesteps:
            nearest_filled_frame = filled_timesteps[np.argmin(abs(filled_timesteps - empty_frame))]
            track_ids = tracks_in_frame[nearest_filled_frame]
            for track_id in track_ids:
                tracks[track_id].masks[empty_frame] = tracks[track_id].masks[nearest_filled_frame]
            tracks_in_frame[empty_frame] = track_ids
            filled_timesteps = np.array(sorted(list(tracks_in_frame.keys())))
    return tracks