from .builder import build_linknet
from ..utils import freeze_model
from ..utils import legacy_support
from ..backbones import get_backbone, get_feature_layers

old_args_map = {
    'freeze_encoder': 'encoder_freeze',
    'skip_connections': 'encoder_features',
    'upsample_layer': 'decoder_block_type',
    'n_upsample_blocks': None,  # removed
    'input_tensor': None,  # removed
    'upsample_kernel_size': None,  # removed
}
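# ``legacy_support`` re-maps deprecated keyword arguments to their current names, so a
# legacy call such as ``Linknet(..., freeze_encoder=True)`` is handled as
# ``Linknet(..., encoder_freeze=True)``; entries mapped to ``None`` mark arguments that
# were removed (the exact handling of removed arguments is defined in ``..utils``).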


@legacy_support(old_args_map)
def Linknet(backbone_name='vgg16',
            input_shape=(None, None, 3),
            classes=1,
            activation='sigmoid',
            encoder_weights='imagenet',
            encoder_freeze=False,
            encoder_features='default',
            decoder_filters=(None, None, None, None, 16),
            decoder_use_batchnorm=True,
            decoder_block_type='upsampling',
            **kwargs):
    """Linknet_ is a fully convolutional neural network for fast image semantic segmentation

    Note:
        This implementation uses 4 skip connections by default (the original paper uses 3).

    Args:
        backbone_name: name of the classification model (without the last dense layers) used as a feature
                    extractor to build the segmentation model.
        input_shape: shape of input data/image ``(H, W, C)``. In the general case you do not need to set
                ``H`` and ``W``; just pass ``(None, None, C)`` to make the model able to process images of
                any size, but ``H`` and ``W`` of input images must be divisible by ``32``.
        classes: number of classes for the output (output shape - ``(h, w, classes)``).
        activation: name of one of ``keras.activations`` for the last model layer
            (e.g. ``sigmoid``, ``softmax``, ``linear``).
        encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
        encoder_freeze: if ``True``, set all layers of the encoder (backbone model) as non-trainable.
        encoder_features: a list of layer numbers or names starting from the top of the model.
                    Each of these layers will be concatenated with the corresponding decoder block. If ``default`` is
                    used, the layer names are taken from the backbone's default feature layers (see ``get_feature_layers``).
        decoder_filters: list of numbers of ``Conv2D`` layer filters in decoder blocks;
            for a block with a skip connection, the number of filters equals the number of filters in the
            corresponding encoder block (estimated automatically when passed as ``None``).
        decoder_use_batchnorm: if ``True``, a ``BatchNormalization`` layer is used between the ``Conv2D``
                    and ``Activation`` layers.
        decoder_block_type: one of
                    - `upsampling`:  use the ``UpSampling2D`` keras layer
                    - `transpose`:   use the ``Conv2DTranspose`` keras layer

    Returns:
        ``keras.models.Model``: **Linknet**
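
    Example:
        A minimal usage sketch (the ``segmentation_models`` import path and the
        ``resnet34`` backbone name are assumptions, not requirements)::

            from segmentation_models import Linknet

            # binary segmentation model with a frozen, ImageNet-pretrained encoder
            model = Linknet(backbone_name='resnet34',
                            classes=1,
                            activation='sigmoid',
                            encoder_weights='imagenet',
                            encoder_freeze=True)
            model.compile('Adam', loss='binary_crossentropy')
            # input images passed to ``model`` must have H and W divisible by 32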

    .. _Linknet:
        https://arxiv.org/pdf/1707.03718.pdf
    """

    # instantiate the classification backbone (without its dense top) to serve as the encoder
    backbone = get_backbone(backbone_name,
                            input_shape=input_shape,
                            weights=encoder_weights,
                            include_top=False)

    # resolve the default skip-connection layers for the chosen backbone
    if encoder_features == 'default':
        encoder_features = get_feature_layers(backbone_name, n=4)

    # assemble the LinkNet decoder on top of the encoder feature maps
    model = build_linknet(backbone,
                          classes,
                          encoder_features,
                          decoder_filters=decoder_filters,
                          upsample_layer=decoder_block_type,
                          activation=activation,
                          n_upsample_blocks=len(decoder_filters),
                          upsample_rates=(2, 2, 2, 2, 2),
                          upsample_kernel_size=(3, 3),
                          use_batchnorm=decoder_use_batchnorm)

    # lock encoder weights for fine-tuning
    if encoder_freeze:
        freeze_model(backbone)
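        # note: Keras reads layer ``trainable`` flags at compile time, so the returned
        # model should be compiled (or re-compiled) after the encoder is frozen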

    model.name = 'link-{}'.format(backbone_name)

    return model