Commit 0e864862 authored by Tyler-D

Add set_regularization

parent 3a33a066
@@ -3,6 +3,7 @@ import warnings
import numpy as np
from functools import wraps
from keras.layers import BatchNormalization
from keras.models import model_from_json
def legacy_support(kwargs_map):
@@ -95,9 +96,9 @@ def add_docstring(doc_string=None):
def recompile(model):
    model.compile(model.optimizer, model.loss, model.metrics)
def freeze_model(model):
"""model all layers non trainable, excluding BatchNormalization layers"""
for layer in model.layers:
@@ -131,3 +132,51 @@ def to_tuple(x):
        return (x, x)
    raise ValueError('Value should be tuple of length 2 or int value, got "{}"'.format(x))
def set_regularization(model,
                       kernel_regularizer=None,
                       bias_regularizer=None,
                       activity_regularizer=None,
                       beta_regularizer=None,
                       gamma_regularizer=None):
    """Set regularizers for all layers of a model.

    Note:
        The returned model's config is updated correctly, so the
        regularizers survive serialization.

    Args:
        model (``keras.models.Model``): instance of keras model
        kernel_regularizer (``regularizer``): regularizer for the kernels
        bias_regularizer (``regularizer``): regularizer for the biases
        activity_regularizer (``regularizer``): regularizer for the layer activations
        beta_regularizer (``regularizer``): regularizer for beta of BatchNormalization layers
        gamma_regularizer (``regularizer``): regularizer for gamma of BatchNormalization layers

    Return:
        out (``Model``): model with updated config
    """
    for layer in model.layers:

        # set kernel_regularizer
        if kernel_regularizer is not None and hasattr(layer, 'kernel_regularizer'):
            layer.kernel_regularizer = kernel_regularizer

        # set bias_regularizer
        if bias_regularizer is not None and hasattr(layer, 'bias_regularizer'):
            layer.bias_regularizer = bias_regularizer

        # set activity_regularizer
        if activity_regularizer is not None and hasattr(layer, 'activity_regularizer'):
            layer.activity_regularizer = activity_regularizer

        # set beta and gamma regularizers of BatchNormalization layers
        if beta_regularizer is not None and hasattr(layer, 'beta_regularizer'):
            layer.beta_regularizer = beta_regularizer

        if gamma_regularizer is not None and hasattr(layer, 'gamma_regularizer'):
            layer.gamma_regularizer = gamma_regularizer

    # mutating layer attributes does not rebuild the already-compiled graph,
    # so round-trip the model through its (updated) json config and copy
    # the weights over to get a model where the regularizers take effect
    out = model_from_json(model.to_json())
    out.set_weights(model.get_weights())
    return out
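For reference, a minimal usage sketch of the new helper (the weighting value is hypothetical; keras' built-in regularizers are assumed):

from keras import regularizers
from segmentation_models import Unet
from segmentation_models.utils import set_regularization

model = Unet('resnet18', encoder_weights=None)
# set_regularization returns a rebuilt model; keep the return value
model = set_regularization(model, kernel_regularizer=regularizers.l2(1e-4))
model.compile('Adam', loss='binary_crossentropy')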
import pytest
import numpy as np
import keras.backend as K
from keras import regularizers

from segmentation_models.utils import set_regularization
from segmentation_models import Unet

X1 = np.ones((1, 32, 32, 3))
Y1 = np.ones((1, 32, 32, 1))

MODEL = Unet
BACKBONE = 'resnet18'

CASE = (
    (X1, Y1, MODEL, BACKBONE),
)
def _test_regularizer(model, reg_model, x, y):

    def zero_loss(gt, pr):
        return pr * 0

    # with a data loss that is identically zero, any remaining loss
    # can only come from the regularization penalties
    model.compile('Adam', loss=zero_loss, metrics=['binary_accuracy'])
    reg_model.compile('Adam', loss=zero_loss, metrics=['binary_accuracy'])

    loss_1, _ = model.test_on_batch(x, y)
    loss_2, _ = reg_model.test_on_batch(x, y)

    assert loss_1 == 0
    assert loss_2 > 0

    K.clear_session()
@pytest.mark.parametrize('case', CASE)
def test_kernel_reg(case):
    x, y, model_fn, backbone = case

    l1_reg = regularizers.l1(0.1)
    model = model_fn(backbone)
    reg_model = set_regularization(model, kernel_regularizer=l1_reg)
    _test_regularizer(model, reg_model, x, y)

    l2_reg = regularizers.l2(0.1)
    model = model_fn(backbone, encoder_weights=None)
    reg_model = set_regularization(model, kernel_regularizer=l2_reg)
    _test_regularizer(model, reg_model, x, y)
"""
Note:
backbone resnet18 use BN after each conv layer --- so no bias used in these conv layers
skip the bias regularizer test
@pytest.mark.parametrize('case', CASE)
def test_bias_reg(case):
x, y, model_fn, backbone = case
l1_reg = regularizers.l1(1)
model = model_fn(backbone)
reg_model = set_regularization(model, bias_regularizer=l1_reg)
_test_regularizer(model, reg_model, x, y)
l2_reg = regularizers.l2(1)
model = model_fn(backbone)
reg_model = set_regularization(model, bias_regularizer=l2_reg)
_test_regularizer(model, reg_model, x, y)
"""
@pytest.mark.parametrize('case', CASE)
def test_bn_reg(case):
    x, y, model_fn, backbone = case

    l1_reg = regularizers.l1(1)
    model = model_fn(backbone)
    reg_model = set_regularization(model, gamma_regularizer=l1_reg)
    _test_regularizer(model, reg_model, x, y)

    model = model_fn(backbone)
    reg_model = set_regularization(model, beta_regularizer=l1_reg)
    _test_regularizer(model, reg_model, x, y)

    l2_reg = regularizers.l2(1)
    model = model_fn(backbone)
    reg_model = set_regularization(model, gamma_regularizer=l2_reg)
    _test_regularizer(model, reg_model, x, y)

    model = model_fn(backbone)
    reg_model = set_regularization(model, beta_regularizer=l2_reg)
    _test_regularizer(model, reg_model, x, y)
@pytest.mark.parametrize('case', CASE)
def test_activity_reg(case):
    x, y, model_fn, backbone = case

    l2_reg = regularizers.l2(1)
    model = model_fn(backbone)
    reg_model = set_regularization(model, activity_regularizer=l2_reg)
    _test_regularizer(model, reg_model, x, y)
if __name__ == '__main__':
    pytest.main([__file__])