Commit 5c7626fd authored by katharina.loeffler's avatar katharina.loeffler
Browse files

handle images with additional single channels

parent 32bbd6fa
......@@ -20,7 +20,7 @@ def run_tracker(img_path, segm_path, res_path, delta_t=3, default_roi_size=2):
# assume img shape z,x,y
dummy = np.squeeze(imread(segm_files[max(segm_files.keys())]))
img_shape = dummy.shape
masks = get_indices_pandas(imread(segm_files[max(segm_files.keys())]))
masks = get_indices_pandas(imread(segm_files[max(segm_files.keys())]).squeeze())
m_shape = np.stack(masks.apply(lambda x: np.max(np.array(x), axis=-1) - np.min(np.array(x), axis=-1) +1))
if len(img_shape) == 2:
......@@ -37,8 +37,9 @@ def run_tracker(img_path, segm_path, res_path, delta_t=3, default_roi_size=2):
tracker = MultiCellTracker(config)
tracks = tracker()
segm_mask_shape = imread(segm_files[max(segm_files.keys())]).shape
exporter = ExportResults()
exporter(tracks, res_path, tracker.img_shape, time_steps=sorted(img_files.keys()))
exporter(tracks, res_path, segm_mask_shape, time_steps=sorted(img_files.keys()))
if __name__ == '__main__':
......
......@@ -126,11 +126,13 @@ class ExportResults:
t_max = sorted(list(tracks_in_frame.keys()))[-1]
z_fill = np.int(np.ceil(max(np.log10(max(1, t_max)), 3))) # either 3 or 4 digits long frame id
squeezed_img_shape = np.array(img_shape) # remove single channels
squeezed_img_shape = tuple(squeezed_img_shape[squeezed_img_shape > 1])
for time, track_ids in tracks_in_frame.items():
tracking_mask = create_tracking_mask_image(all_tracks, time, track_ids, img_shape)
tracking_mask = create_tracking_mask_image(all_tracks, time, track_ids, squeezed_img_shape)
file_name = self.img_file_name + str(time).zfill(z_fill) + self.img_file_ending
tracking_mask = np.array(np.squeeze(tracking_mask), dtype=np.uint16)
tracking_mask = np.array(np.squeeze(tracking_mask), dtype=np.uint16).reshape(img_shape)
imsave(os.path.join(export_dir, file_name), tracking_mask, compress=1)
......
......@@ -68,7 +68,7 @@ class MultiCellTracker:
def propagate_tracklets(self, time):
"""Propagates object position and features over time."""
image = imread(self.config.get_image_file(time))
image = imread(self.config.get_image_file(time)).squeeze()
if self.img_shape is None:
self.img_shape = image.shape
segmentation, mask_indices = self.config.get_segmentation_masks(time)
......@@ -391,7 +391,7 @@ class TrackingConfig:
return self.img_files[time_step]
def get_segmentation_masks(self, time_step):
segmentation = imread(self.segm_files[time_step])
segmentation = imread(self.segm_files[time_step]).squeeze()
segmentation = np.squeeze(segmentation)
return segmentation, get_indices_pandas(segmentation)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment