import copy
import warnings

import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import Callback

from beyondml.tflow.layers import MaskedDense, MaskedConv2D, MaskedConv3D, FilterLayer, SumLayer, SelectorLayer, MultiMaskedDense, MultiMaskedConv2D, MultiMaskedConv3D, MultiDense, MultiConv2D, MultiConv3D, MultiMaxPool2D, MultiMaxPool3D, SparseDense, SparseConv2D, SparseConv3D, SparseMultiDense, SparseMultiConv2D, SparseMultiConv3D, MultitaskNormalization
# Multitask layers that carry per-task masks
MULTI_MASKING_LAYERS = (MultiMaskedDense, MultiMaskedConv2D, MultiMaskedConv3D)

# Every layer that carries masks: the single-task masked layers followed by the multitask ones
MASKING_LAYERS = (MaskedDense, MaskedConv2D, MaskedConv3D) + MULTI_MASKING_LAYERS

# Multitask layers without masks
NON_MASKING_LAYERS = (MultiDense, MultiConv2D, MultiConv3D)

# Sparse layer variants
SPARSE_LAYERS = (
    SparseDense, SparseConv2D, SparseConv3D,
    SparseMultiDense, SparseMultiConv2D, SparseMultiConv3D,
)

# All beyondml layer classes this module works with (order is relied on by get_custom_objects)
CUSTOM_LAYERS = (
    MASKING_LAYERS
    + NON_MASKING_LAYERS
    + SPARSE_LAYERS
    + (FilterLayer, SumLayer, SelectorLayer, MultiMaxPool2D, MultiMaxPool3D, MultitaskNormalization)
)
class ActiveSparsification(Callback):
    """
    Keras-compatible callback object which enables active sparsification, allowing for increased sparsification as models
    train.

    Training proceeds in two phases:

    1. Pruning: every epoch whose performance meets `performance_cutoff` snapshots the weights and increases the
       sparsification by `sparsification_rate` (never exceeding `max_sparsification`). After
       `sparsification_patience` consecutive epochs that miss the cutoff, the last snapshot is restored and pruning
       stops.
    2. Early stopping: training stops once performance has failed to improve by `stopping_delta` for
       `stopping_patience` consecutive epochs.
    """

    def __init__(
        self,
        performance_cutoff,
        performance_measure='auto',
        starting_sparsification=None,
        max_sparsification=99,
        sparsification_rate=1,
        sparsification_patience=10,
        stopping_delta=0.01,
        stopping_patience=5,
        restore_best_weights=True,
        verbose=1
    ):
        """
        Parameters
        ----------
        performance_cutoff : float
            The cutoff value that the performance measure must meet or beat in order to iteratively sparsify
        performance_measure : str (default 'auto')
            The performance measure (a key of the Keras epoch logs) used in conjunction with `performance_cutoff`
        starting_sparsification : int or None (default None)
            The starting sparsification that the model has already been sparsified to. If `None`, then defaults to 0
        max_sparsification : int (default 99)
            The maximum sparsification allowed
        sparsification_rate : int (default 1)
            The increase in sparsification that occurs when model performance beats the performance cutoff
        sparsification_patience : int (default 10)
            The number of epochs the model is allowed to train for without beating the performance measure before
            stopping sparsification
        stopping_delta : float (default 0.01)
            The performance improvement that must be seen when pruning has stopped and early stopping is being
            considered
        stopping_patience : int (default 5)
            The number of epochs the model is allowed to train for without `stopping_delta` improvement before
            stopping training
        restore_best_weights : bool (default True)
            Whether to restore the best weights when early stopping ends training
        verbose : int or bool (default 1)
            Verbosity level for logging.

        Notes
        -----
        - If `performance_measure` is 'auto', it is resolved on the first epoch to the first of the following
          present in the logs: 'val_accuracy', 'val_loss', 'accuracy', otherwise 'loss'
        - If `performance_measure` contains 'accuracy', then `performance_cutoff` is the minimum value that must be
          met. Otherwise the measure is treated as a loss and `performance_cutoff` is the maximum value allowed
        """
        # Run Callback.__init__ itself; `super(Callback, self)` would skip it and call object.__init__
        super().__init__()
        self.performance_cutoff = performance_cutoff
        self.performance_measure = performance_measure
        self.starting_sparsification = starting_sparsification
        self.max_sparsification = max_sparsification
        self.sparsification_rate = sparsification_rate
        self.sparsification_patience = sparsification_patience
        self.stopping_delta = stopping_delta
        self.stopping_patience = stopping_patience
        self.restore_best_weights = restore_best_weights
        self.verbose = int(verbose)

    def on_train_begin(self, logs=None):
        """Reset the pruning and early-stopping state at the start of training."""
        self.prune_wait = 0
        self.stop_wait = 0
        self.best_weights = self.model.get_weights()
        self.best = None
        self.pruning = True
        self.sparsification = self.starting_sparsification if self.starting_sparsification is not None else 0
        # Sparsification level at which `best_weights` was captured, used when pruning is rolled back
        self.best_sparsification = self.sparsification

    def on_epoch_end(self, epoch, logs=None):
        """Sparsify, roll back, or early-stop the model depending on this epoch's performance."""
        logs = logs or {}
        if self.performance_measure == 'auto':
            self.performance_measure = self._resolve_performance_measure(logs)
            if self.verbose:
                print(f'Performance measure set to {self.performance_measure}')
        performance = logs[self.performance_measure]
        if self.best is None:
            self.best = performance
        if self.pruning:
            self._pruning_step(performance)
        else:
            self._early_stopping_step(performance)

    @staticmethod
    def _resolve_performance_measure(logs):
        """Return the first of 'val_accuracy', 'val_loss', 'accuracy' found in `logs`, else 'loss'."""
        for candidate in ('val_accuracy', 'val_loss', 'accuracy'):
            if candidate in logs:
                return candidate
        return 'loss'

    def _higher_is_better(self):
        """Accuracy-style measures are maximized; every other measure is treated as a loss and minimized."""
        return 'accuracy' in self.performance_measure

    def _meets_cutoff(self, performance):
        """Whether `performance` is good enough to keep sparsifying."""
        if self._higher_is_better():
            return performance >= self.performance_cutoff
        return performance <= self.performance_cutoff

    def _has_improved(self, performance):
        """Whether `performance` improves on `self.best` by at least `stopping_delta`."""
        if self._higher_is_better():
            return performance >= self.best + self.stopping_delta
        return performance <= self.best - self.stopping_delta

    def _pruning_step(self, performance):
        """Handle one epoch of the pruning phase."""
        if self._meets_cutoff(performance):
            # Snapshot the current (acceptable) model before sparsifying further
            self.best_weights = self.model.get_weights()
            self.best_sparsification = self.sparsification
            self.best = performance
            next_sparsification = self.sparsification + self.sparsification_rate
            if next_sparsification > self.max_sparsification:
                if self.verbose:
                    print(
                        'Model cannot be sparsified further due to max sparsification parameter')
                self.pruning = False
            else:
                self._sparsify_model(next_sparsification)
                self.sparsification = next_sparsification
                self.prune_wait = 0
                if self.verbose:
                    print(
                        f'Model performance reached {round(performance, 2)}, sparsifying to {self.sparsification}')
            return

        self.prune_wait += 1
        if self.verbose:
            print(
                f'Model performance has not reached pruning threshold for {self.prune_wait} epoch(s)')
        if self.prune_wait >= self.sparsification_patience:
            self.pruning = False
            # Roll back to the last snapshot that met the cutoff, and track its sparsification level
            self.model.set_weights(self.best_weights)
            self.sparsification = self.best_sparsification
            if self.verbose:
                print(
                    f'Model performance has not reached pruning threshold for {self.prune_wait} epochs, reverting to {self.sparsification} sparsification and beginning early stopping')

    def _early_stopping_step(self, performance):
        """Handle one epoch of the early-stopping phase."""
        if self._has_improved(performance):
            self.best_weights = self.model.get_weights()
            self.best = performance
            self.stop_wait = 0
            if self.verbose:
                print(
                    f'Model performance improved to {round(self.best, 2)}')
            return

        self.stop_wait += 1
        if self.verbose:
            print(
                f'Early stopping performance has not met threshold for {self.stop_wait} epochs')
        if self.stop_wait >= self.stopping_patience:
            if self.restore_best_weights:
                self.model.set_weights(self.best_weights)
            if self.verbose:
                print(
                    'Model performance has not met early stopping criteria. Stopping training')
            self.model.stop_training = True

    def _sparsify_model(self, percentage):
        """
        Sparsify the model in place to `percentage` using magnitude-based masking

        A copy of the model is rebuilt from its config, masked with `mask_model(..., method='magnitude')`,
        and the copy's weights are written back, so the trained model's compile state is left untouched.
        """
        try:
            new_model = tf.keras.models.Model.from_config(
                self.model.get_config(), custom_objects=get_custom_objects())
        except Exception:
            new_model = tf.keras.models.Sequential.from_config(
                self.model.get_config(), custom_objects=get_custom_objects())
        new_model.set_weights(self.model.get_weights())
        self.model.set_weights(
            mask_model(
                new_model,
                percentage,
                method='magnitude'
            ).get_weights()
        )
def _get_masking_gradients(
    model,
    x,
    y
):
    """
    Obtain masking layer gradients with respect to the tasks presented

    Parameters
    ----------
    model : tf.keras Model
        The model to get the gradients of
    x : np.array or array-like
        The input data
    y : np.array or array-like
        The true output

    Returns
    -------
    masking_gradients : list
        A list of gradients for the masking weights for the model, one entry per masking layer

    Raises
    ------
    ValueError
        If any output array is one-dimensional
    """
    # Check outputs
    if isinstance(y, list):
        if not all(len(val.shape) > 1 for val in y):
            raise ValueError(
                'Error in output shapes. If any tasks have a single output, please reshape the value using the `.reshape(-1, 1)` method')
    elif not len(y.shape) > 1:
        raise ValueError(
            'Error in output shapes. If your task has a single output, please reshape the value using the `.reshape(-1, 1)` method')

    # Grab the weights for the masking layers
    masking_weights = [
        layer.trainable_weights for layer in model.layers if isinstance(layer, MASKING_LAYERS)
    ]

    # Build one loss function per task; a single (non-list) loss is shared by every task.
    # Previously a callable loss was immediately overwritten by `tf.keras.losses.get(<list>)`.
    loss_fns = model.loss
    if not isinstance(loss_fns, list):
        if not callable(loss_fns):
            loss_fns = tf.keras.losses.get(loss_fns)
        loss_fns = [loss_fns] * len(x)
    else:
        loss_fns = [loss if callable(loss) else tf.keras.losses.get(loss) for loss in loss_fns]

    # Grab the gradients for the specified weights
    with tf.GradientTape() as tape:
        raw_preds = model(x)
        loss_values = [loss_fns[i](y[i], raw_preds[i]) for i in range(len(loss_fns))]
    gradients = tape.gradient(loss_values, masking_weights)
    return gradients
def get_custom_objects():
    """
    Return a dictionary of custom objects (layers) to use when loading models trained using this package

    Keys are layer class names and values are the corresponding classes from `CUSTOM_LAYERS`.
    """
    # Must list one name per entry of CUSTOM_LAYERS, in the same order. `zip` silently
    # truncates, which previously dropped 'MultitaskNormalization' from the mapping.
    names = (
        'MaskedDense', 'MaskedConv2D', 'MaskedConv3D',
        'MultiMaskedDense', 'MultiMaskedConv2D', 'MultiMaskedConv3D',
        'MultiDense', 'MultiConv2D', 'MultiConv3D',
        'SparseDense', 'SparseConv2D', 'SparseConv3D',
        'SparseMultiDense', 'SparseMultiConv2D', 'SparseMultiConv3D',
        'FilterLayer', 'SumLayer', 'SelectorLayer', 'MultiMaxPool2D', 'MultiMaxPool3D',
        'MultitaskNormalization',
    )
    return dict(zip(names, CUSTOM_LAYERS))
def _multitask_gradient_mask(grad, percentile, exclusive):
    """Build a per-task percentile mask from one absolute multitask gradient array (task axis first)."""
    new_mask = np.zeros(grad.shape)
    used_weights = np.zeros(grad.shape[1:])
    for task_idx in range(grad.shape[0]):
        if exclusive:
            # Zero the gradients of weights already claimed by any earlier task. `> 0` (not `== 1`)
            # also catches weights selected by several earlier tasks.
            grad[task_idx][used_weights > 0] = 0
        new_mask[task_idx] = (grad[task_idx] >= np.percentile(
            grad[task_idx], percentile)).astype(int)
        used_weights += new_mask[task_idx]
    return new_mask


def _multitask_magnitude_mask(weight, percentile, exclusive):
    """Build a per-task percentile mask from one absolute multitask weight array (task axis first)."""
    new_mask = np.zeros(weight.shape)
    for task_idx in range(weight.shape[0]):
        candidate = weight[task_idx]
        if exclusive:
            # Suppress weights already selected by earlier tasks; the threshold itself is still
            # computed from this task's full weight magnitudes
            candidate = candidate * (1 - new_mask[:task_idx].sum(axis=0))
        new_mask[task_idx] = (candidate >= np.percentile(
            weight[task_idx], percentile)).astype(int)
    return new_mask


def mask_model(
    model,
    percentile,
    method='gradients',
    exclusive=True,
    x=None,
    y=None
):
    """
    Mask the multitask model for training respective using the gradients for the tasks at hand

    Parameters
    ----------
    model : keras model with MANN masking layers
        The model to be masked
    percentile : int
        Percentile to use in masking. Any weights less than the `percentile` value will be made zero
    method : str (default 'gradients')
        One of either 'gradients' or 'magnitude' - the method for how to identify weights to mask
        If method is 'gradients', utilizes the gradients with respect to the passed x and y variables
        to identify the subnetwork to activate for each task
        If method is 'magnitude', uses the magnitude of the weights to identify the subnetwork to activate for each task
    exclusive : bool (default True)
        Whether to restrict previously-used weight indices for each task. If `True`, this identifies disjoint subsets of
        weights within the layer which perform the tasks requested.
    x : list of np.ndarray or array-like
        The training data input values, ignored if "method" is 'magnitude'
    y : list of np.ndarray or array-like
        The training data output values, ignored if "method" is 'magnitude'

    Returns
    -------
    model : keras model
        The same model object, with new masks set on its masking layers and recompiled via `model.compile()`

    Raises
    ------
    ValueError
        If `method` is not 'gradients' or 'magnitude'
    """
    # Check method
    method = method.lower()
    if method not in ['gradients', 'magnitude']:
        raise ValueError(
            f"method must be one of 'gradients', 'magnitude', got {method}")

    if method == 'gradients':
        grads = _get_masking_gradients(
            model,
            x,
            y
        )
        # `grads` has one entry per masking layer, in layer order
        gradient_idx = 0
        for layer in model.layers:
            if isinstance(layer, tf.keras.models.Model):
                warnings.warn(
                    'mask_model does not effectively support models with models as layers if method is "gradients". Please set method to "magnitude"', RuntimeWarning)
            if isinstance(layer, MASKING_LAYERS):
                layer_grads = [np.abs(np.asarray(grad)) for grad in grads[gradient_idx]]
                if isinstance(layer, MULTI_MASKING_LAYERS):
                    new_masks = [
                        _multitask_gradient_mask(grad, percentile, exclusive) for grad in layer_grads
                    ]
                else:
                    new_masks = [
                        (grad >= np.percentile(grad, percentile)).astype(int) for grad in layer_grads
                    ]
                layer.set_masks(new_masks)
                gradient_idx += 1
    else:
        # method == 'magnitude'
        for layer in model.layers:
            if isinstance(layer, MASKING_LAYERS):
                weights = [np.abs(weight.numpy())
                           for weight in layer.trainable_weights]
                if isinstance(layer, MULTI_MASKING_LAYERS):
                    new_masks = [
                        _multitask_magnitude_mask(weight, percentile, exclusive) for weight in weights
                    ]
                else:
                    new_masks = [
                        (weight >= np.percentile(weight, percentile)).astype(int) for weight in weights
                    ]
                layer.set_masks(new_masks)
            elif isinstance(layer, tf.keras.models.Model):
                # Recurse into nested models
                mask_model(
                    layer,
                    percentile,
                    method,
                    exclusive
                )

    # Compile the model again so the effects take place
    model.compile()
    return model
def replace_config(config):
    """
    Replace the model config to remove masking layers

    Parameters
    ----------
    config : dict
        A Keras model config with a 'layers' list, e.g. as returned by `model.get_config()`

    Returns
    -------
    new_config : dict
        A deep copy of `config` in which each masking layer's class name is mapped to its unmasked
        counterpart (e.g. 'MaskedDense' -> 'Dense', 'MultiMaskedDense' -> 'MultiDense') and its
        'mask_initializer' entry is removed. Nested 'Functional'/'Sequential' configs are processed
        recursively. `config` itself is not modified.
    """
    # Deep copy: a shallow copy shares the nested layer dicts, so rewriting them mutated the caller's config
    new_config = copy.deepcopy(config)
    layer_mapping = {
        'MaskedConv2D': 'Conv2D',
        'MaskedConv3D': 'Conv3D',
        'MaskedDense': 'Dense',
        'MultiMaskedConv2D': 'MultiConv2D',
        'MultiMaskedConv3D': 'MultiConv3D',
        'MultiMaskedDense': 'MultiDense'
    }
    model_classes = ('Functional', 'Sequential')
    for layer_config in new_config['layers']:
        class_name = layer_config['class_name']
        if class_name in layer_mapping:
            new_class_name = layer_mapping[class_name]
            layer_config['class_name'] = new_class_name
            if layer_config.get('module'):
                # NOTE(review): this also points MultiMasked* layers at tensorflow.keras.layers, although
                # MultiDense/MultiConv* are beyondml layers — confirm this is intended
                layer_config['module'] = layer_config['module'].replace(
                    class_name, new_class_name
                ).replace('beyondml.tflow.layers', 'tensorflow.keras.layers')
            if layer_config['config'].get('mask_initializer'):
                del layer_config['config']['mask_initializer']
        elif class_name in model_classes:
            layer_config['config'] = replace_config(layer_config['config'])
    return new_config
def _create_masking_config(config):
    """
    Replace the model config to add masking layers

    Parameters
    ----------
    config : dict
        A Keras model config with a 'layers' list, e.g. as returned by `model.get_config()`

    Returns
    -------
    new_config : dict
        A deep copy of `config` in which Conv2D/Conv3D/Dense (and their Multi* variants) are mapped to the
        corresponding masking layers, each given an all-ones 'mask_initializer'. Nested
        'Functional'/'Sequential' configs are processed recursively. `config` itself is not modified.
    """
    # Deep copy: a shallow copy shares the nested layer dicts, so rewriting them mutated the caller's config
    new_config = copy.deepcopy(config)
    layer_mapping = {
        'Conv2D': 'MaskedConv2D',
        'Conv3D': 'MaskedConv3D',
        'Dense': 'MaskedDense',
        'MultiConv2D': 'MultiMaskedConv2D',
        'MultiConv3D': 'MultiMaskedConv3D',
        'MultiDense': 'MultiMaskedDense'
    }
    model_classes = ('Functional', 'Sequential')
    for layer_config in new_config['layers']:
        class_name = layer_config['class_name']
        if class_name in layer_mapping:
            new_class_name = layer_mapping[class_name]
            layer_config['class_name'] = new_class_name
            if layer_config.get('module'):
                layer_config['module'] = layer_config['module'].replace(
                    class_name, new_class_name
                )
            # All-ones masks, presumably leaving every weight active until the model is masked
            layer_config['config']['mask_initializer'] = tf.keras.initializers.serialize(
                tf.keras.initializers.get('ones')
            )
        elif class_name in model_classes:
            layer_config['config'] = _create_masking_config(layer_config['config'])
    return new_config
def _quantize_model_config(config, dtype='float16'):
"""
Change the dtype of the model
"""
model_classes = ('Functional', 'Sequential')
new_config = config.copy()
for i in range(len(new_config['layers'])):
if new_config['layers'][i]['class_name'] in model_classes:
new_config['layers'][i] = _quantize_model_config(
new_config['layers'][i]['config'], dtype)
else:
new_config['layers'][i]['config']['dtype'] = dtype
return new_config
def replace_weights(new_model, old_model):
    """
    Replace the weights of a newly created model with the weights (sans masks) of an old model

    Layers are paired by position. Nested models are handled recursively. For a masking layer in
    `old_model`, only as many leading weights as the matching layer of `new_model` holds are copied,
    which drops the trailing masks.

    Returns
    -------
    new_model : TensorFlow Keras model
        `new_model`, recompiled via `compile()`
    """
    for idx, new_layer in enumerate(new_model.layers):
        old_layer = old_model.layers[idx]
        if isinstance(new_layer, tf.keras.models.Model):
            # Nested model: recurse
            replace_weights(new_layer, old_layer)
            continue
        source_weights = old_layer.get_weights()
        if isinstance(old_layer, MASKING_LAYERS):
            # Keep only the weights the unmasked layer actually has
            source_weights = source_weights[:len(new_layer.get_weights())]
        new_layer.set_weights(source_weights)
    new_model.compile()
    return new_model
def _replace_masking_weights(new_model, old_model):
    """
    Replace the weights of a newly created model with the weights (adding masks) of an old model

    Layers are paired by position and nested models are handled recursively. For each masking layer of
    `new_model`, the old layer's weights are copied first and the new layer's remaining trailing weights
    (its masks) are kept as they are.

    Returns
    -------
    new_model : TensorFlow Keras model
        `new_model`, recompiled via `compile()`
    """
    for idx, new_layer in enumerate(new_model.layers):
        old_layer = old_model.layers[idx]
        if isinstance(new_layer, tf.keras.models.Model):
            _replace_masking_weights(new_layer, old_layer)
        elif isinstance(new_layer, MASKING_LAYERS):
            copied = old_layer.get_weights()
            # Append the new layer's own weights beyond those copied (the freshly initialized masks)
            new_layer.set_weights(copied + new_layer.get_weights()[len(copied):])
        else:
            new_layer.set_weights(old_layer.get_weights())
    new_model.compile()
    return new_model
def add_layer_masks(model, additional_custom_objects=None):
    """
    Convert a trained model from one that does not have masking weights to one that does have
    masking weights

    Parameters
    ----------
    model : TensorFlow Keras model
        The model to be converted
    additional_custom_objects : dict or None (default None)
        Additional custom layers to use

    Returns
    -------
    new_model : TensorFlow Keras model
        The converted model, carrying `model`'s weights plus initial masks, compiled via `compile()`
    """
    custom_objects = get_custom_objects()
    if additional_custom_objects is not None:
        custom_objects.update(additional_custom_objects)

    # Replace the config of the model
    new_config = _create_masking_config(model.get_config())

    # `from_config` is a classmethod: call it on the class instead of constructing a throwaway
    # model first (consistent with quantize_model). Fall back to Sequential if Model fails.
    try:
        new_model = tf.keras.models.Model.from_config(
            new_config,
            custom_objects=custom_objects
        )
    except Exception:
        new_model = tf.keras.models.Sequential.from_config(
            new_config,
            custom_objects=custom_objects
        )

    # Replace the weights of the new model
    new_model = _replace_masking_weights(new_model, model)

    # Compile and return the model
    new_model.compile()
    return new_model
def quantize_model(model, dtype='float16', additional_custom_objects=None):
    """
    Apply model quantization

    Parameters
    ----------
    model : TensorFlow Keras Model
        The model to quantize
    dtype : str or TensorFlow datatype (default 'float16')
        The datatype to quantize to
    additional_custom_objects : None or dict (default None)
        Additional custom objects to use to instantiate the model

    Returns
    -------
    new_model : TensorFlow Keras Model
        The quantized model
    """
    # Grab the configuration and weights from the original model
    model_config = model.get_config()
    weights = model.get_weights()

    # NumPy cannot interpret a tf.DType directly, so map it to its NumPy equivalent for the cast;
    # strings and NumPy dtypes have no `as_numpy_dtype` and are used unchanged
    numpy_dtype = getattr(dtype, 'as_numpy_dtype', dtype)
    new_weights = [
        np.array(w, dtype=numpy_dtype) for w in weights
    ]

    # Change the config to get the quantized configuration
    new_config = _quantize_model_config(model_config, dtype)

    # Instantiate the new model from the new config
    custom_objects = get_custom_objects()
    if additional_custom_objects is not None:
        custom_objects.update(additional_custom_objects)
    try:
        new_model = tf.keras.models.Model.from_config(
            new_config, custom_objects=custom_objects)
    except Exception:
        new_model = tf.keras.models.Sequential.from_config(
            new_config, custom_objects=custom_objects)

    # Set the weights of the new model
    new_model.set_weights(new_weights)
    return new_model
def _get_masking_weights(model):
    """
    Collect the weights of every masking layer in a model

    Parameters
    ----------
    model : TensorFlow Keras model
        The model to get the masking weights of

    Returns
    -------
    weights : list of lists of TensorFlow variables
        One entry per top-level masking layer (in layer order), holding that layer's `weights`
    """
    collected = []
    for layer in model.layers:
        if isinstance(layer, MASKING_LAYERS):
            collected.append(layer.weights)
    return collected
def get_task_masking_gradients(
    model,
    task_num
):
    """
    Get the gradients of masking weights within a model

    Parameters
    ----------
    model : TensorFlow Keras model
        The model to retrieve the gradients of
    task_num : int
        Index of the task (model output) whose loss the gradients are taken with respect to.
        The losses of all other tasks are weighted by 0. Must be 0 for a single-output model.

    Notes
    -----
    - This function should only be run *before* the model has been trained
      or used to predict. There is an unknown bug related to TensorFlow which
      is leading to incorrect results after initial training
    - When running this function, randomized input and output data is sent
      through the model to retrieve gradients respective to each task. If
      the model is compiled using `sparse_categorical_crossentropy' loss,
      this will break this function's functionality. As a result, please
      use `categorical_crossentropy` (or even better, `mse`) before running this function. After
      retrieving gradients, the model can be recompiled with whatever parameters are desired.

    Returns
    -------
    gradients : list of lists of TensorFlow tensors (entries may be None)
        The gradients of the masking weights of the model, one list per masking layer
    """
    # Figure out the number of tasks
    output_shapes = model.output_shape
    multi_output = isinstance(output_shapes, list)
    num_tasks = len(output_shapes) if multi_output else 1

    # Only the requested task contributes to the loss. This is defined for single-output
    # models too (previously `loss_weights` was unbound there, raising NameError).
    loss_weights = [0] * num_tasks
    loss_weights[task_num] = 1

    # Get the masking weights
    masking_weights = _get_masking_weights(model)

    def _random_batch(shape):
        # Random data for `shape`, with unknown (None) dimensions such as the batch size set to 1
        return np.random.random([1 if dim is None else dim for dim in shape])

    # Configure inputs and outputs
    input_shapes = model.input_shape
    if not isinstance(input_shapes, list):
        input_shapes = [input_shapes]
    inputs = [_random_batch(shape) for shape in input_shapes]
    outputs = [_random_batch(shape) for shape in (output_shapes if multi_output else [output_shapes])]

    # Configure the losses
    losses = model.loss
    if not isinstance(losses, list):
        losses = [losses] * num_tasks
    losses = [
        tf.keras.losses.get(loss) for loss in losses
    ]

    # Get the gradients of the weights wrt the task
    with tf.GradientTape() as tape:
        raw_preds = model(inputs)
        if not multi_output:
            # A single-output model returns one tensor; wrap it so raw_preds[0] is the whole
            # prediction rather than the first sample of the batch
            raw_preds = [raw_preds]
        loss_values = [losses[i](outputs[i], raw_preds[i]) * loss_weights[i]
                       for i in range(len(losses))]
    gradients = tape.gradient(loss_values, masking_weights)
    return gradients
def mask_task_weights(
    model,
    task_masking_gradients,
    percentile,
    respect_previous_tasks=True
):
    """
    Re-mask the weights that a single task uses, based on that task's gradients

    Parameters
    ----------
    model : TensorFlow Keras model
        The model to be masked
    task_masking_gradients : list of TensorFlow tensors
        The gradients for the specific task requested (one list per masking layer, as returned by
        `get_task_masking_gradients`)
    percentile : int
        The percentile to mask/prune
    respect_previous_tasks : bool (default True)
        Whether to respect the weights used for previous tasks and not use them
        for subsequent tasks

    Returns
    -------
    masked_model : TensorFlow Keras model
        The masked model (the same object, recompiled via `compile()`)
    """
    masking_layers = [layer for layer in model.layers if isinstance(layer, MASKING_LAYERS)]
    # NumPy snapshot of each masking layer's weights, aligned with `masking_layers`
    layer_weights = [layer.get_weights() for layer in masking_layers]

    for masking_idx, layer in enumerate(masking_layers):
        weights = layer_weights[masking_idx]
        gradients = task_masking_gradients[masking_idx]
        new_masks = []

        if isinstance(layer, MULTI_MASKING_LAYERS):
            # The code assumes masks follow the weights they mask: mask k sits at index k + len // 2
            mask_offset = len(weights) // 2
            for weight_num, gradient in enumerate(gradients):
                weight = weights[weight_num]
                # A None gradient marks an entry that is itself a mask
                if gradient is None:
                    continue
                # The task being masked is the last one with any non-zero gradient
                active_tasks = [
                    task_idx for task_idx in range(gradient.shape[0])
                    if not (gradient[task_idx].numpy() == 0).all()
                ]
                if not active_tasks:
                    continue
                task_idx_num = active_tasks[-1]

                task_weight = np.abs(weight[task_idx_num])
                if respect_previous_tasks and task_idx_num > 0:
                    # Exclude weights that are non-zero for any earlier task
                    task_weight[(weight[:task_idx_num] != 0).any(axis=0)] = 0

                # Overwrite only this task's slice of the existing mask
                layer_mask = weights[weight_num + mask_offset]
                layer_mask[task_idx_num] = task_weight >= np.percentile(task_weight, percentile)
                new_masks.append(layer_mask)
        else:
            for weight_num, gradient in enumerate(gradients):
                weight = weights[weight_num]
                # Skip masks (None gradient) and weights this task does not touch
                if gradient is None or (gradient.numpy() == 0).all():
                    continue
                magnitude = np.abs(weight)
                new_masks.append(magnitude >= np.percentile(magnitude, percentile))

        # Some layers may receive no new masks at all
        if new_masks:
            layer.set_masks(new_masks)

    model.compile()
    return model
def _per_task_value(value, task_num):
    """Return `value[task_num]` when `value` is a per-task sequence, otherwise the scalar `value` itself."""
    if isinstance(value, (list, tuple, np.ndarray)):
        return value[task_num]
    return value


def train_model_iteratively(
    model,
    task_gradients,
    train_x,
    train_y,
    validation_split,
    delta,
    batch_size,
    losses,
    optimizer='adam',
    metrics=None,
    starting_pruning=0,
    pruning_rate=10,
    patience=5,
    max_epochs=100
):
    """
    Train a model iteratively on each task, first obtaining
    baseline performance on each task and then iteratively
    training and pruning each task as far back as possible while
    maintaining acceptable performance on each task

    Parameters
    ----------
    model : TensorFlow Keras model
        The model to be trained
    task_gradients : list of TensorFlow tensors
        Gradients for each task, output from the `get_task_masking_gradients` function
    train_x : list of numpy arrays, TensorFlow Datasets, or other
              data types models can train with
        The input data to use to train on, one entry per task
    train_y : list of numpy arrays, TensorFlow Datasets, or other
              data types model can train with
        The output data to use to train on, one entry per task
    validation_split : float, or list of float
        The proportion of data to use for validation
    delta : float
        The tolerance between validation losses to be considered "acceptable"
        performance to continue
    batch_size : int
        The batch size to train with
    losses : str, list, or Keras loss function
        The loss or losses to use when training
    optimizer : str, list, or Keras optimizer
        The optimizer to use when training (default 'adam')
    metrics : list or None (default None)
        Metrics passed to `model.compile`
    starting_pruning : int or list of int (default 0)
        The starting pruning rate to use for each task
    pruning_rate : int or list of int (default 10)
        The pruning rate to use
    patience : int (default 5)
        The patience for number of epochs to wait for performance to improve sufficiently
    max_epochs : int or list of int (default 100)
        The maximum number of epochs to use for training each task

    Returns
    -------
    model : TensorFlow Keras model
        The trained and pruned model
    """
    # The number of tasks is taken from the number of training inputs
    num_tasks = len(train_x)
    # Keep track of the amount of the model currently used
    amount_used = 100
    for task_num in range(num_tasks):
        print(f'Training task {task_num}')

        # Resolve per-task settings: each may be a scalar or a per-task list
        task_start_pruning = _per_task_value(starting_pruning, task_num)
        current_pruning_rate = _per_task_value(pruning_rate, task_num)
        current_validation_split = _per_task_value(validation_split, task_num)
        current_epochs = _per_task_value(max_epochs, task_num)

        # Only the current task's loss contributes
        loss_weights = [0] * num_tasks
        loss_weights[task_num] = 1

        # Compile and train the model initially
        model.compile(
            loss=losses,
            optimizer=optimizer,
            loss_weights=loss_weights,
            metrics=metrics
        )
        callback = tf.keras.callbacks.EarlyStopping(
            min_delta=delta,
            patience=patience,
            restore_best_weights=True
        )
        history = model.fit(
            train_x[task_num],
            train_y[task_num],
            batch_size=batch_size,
            epochs=current_epochs,
            validation_split=current_validation_split,
            callbacks=[callback],
            verbose=2
        )

        # Retrieve the validation loss and current best weights
        best_loss = min(history.history['val_loss'])
        best_weights = model.get_weights()

        # Training loop for the task at hand
        # NOTE(review): current_wait is not reset after a successful pruning level, so the
        # patience budget is shared across all pruning levels of this task — confirm intent
        current_wait = 0
        if task_num == 0:
            current_prune = task_start_pruning
        else:
            current_prune = max(task_start_pruning, amount_used)
        keep_training = True
        just_started = True

        # If pruning needs to occur, do it
        if current_prune != 0:
            print(f'Pruning task to {current_prune}')
            model = mask_task_weights(
                model,
                task_gradients[task_num],
                current_prune
            )

        while keep_training:
            if current_prune + current_pruning_rate < 100:
                # The first pass trains at the starting pruning level; later passes prune further
                if not just_started:
                    current_prune += current_pruning_rate
                    model = mask_task_weights(
                        model,
                        task_gradients[task_num],
                        current_prune
                    )
                    print(f'Pruning task to {current_prune}')
                else:
                    just_started = False

                # Recompile the model
                model.compile(
                    loss=losses,
                    optimizer=optimizer,
                    loss_weights=loss_weights,
                    metrics=metrics
                )

                # Train one epoch at a time until the loss is acceptable or patience runs out
                while current_wait < patience:
                    history = model.fit(
                        train_x[task_num],
                        train_y[task_num],
                        batch_size=batch_size,
                        validation_split=current_validation_split,
                        verbose=2
                    )
                    loss = history.history['val_loss'][-1]
                    # NOTE(review): best_loss is replaced by any acceptable loss, even a larger
                    # one, so the tolerance can drift upward across pruning levels
                    if loss < best_loss + delta:
                        best_weights = model.get_weights()
                        best_loss = loss
                        break
                    else:
                        current_wait += 1

                # Stop when patience is exhausted or the next step would reach 100%
                if current_wait == patience or current_prune + current_pruning_rate >= 100:
                    keep_training = False
            else:
                keep_training = False

        # Record how much of the model has been used
        if task_num == 0:
            amount_used -= current_prune
        else:
            amount_used += 100 - current_prune

        # Restore the best weights found, then fine-tune with early stopping
        model.set_weights(best_weights)
        model.compile(
            loss=losses,
            optimizer=optimizer,
            loss_weights=loss_weights,
            metrics=metrics
        )
        model.fit(
            train_x[task_num],
            train_y[task_num],
            batch_size=batch_size,
            epochs=current_epochs,
            validation_split=current_validation_split,
            callbacks=[callback],
            verbose=2
        )
    return model
def train_model(
    model,
    train_x,
    train_y,
    loss,
    metrics,
    optimizer,
    cutoff,
    batch_size=32,
    epochs=100,
    starting_sparsification=0,
    max_sparsification=99,
    sparsification_rate=5,
    sparsification_patience=10,
    stopping_patience=5
):
    """
    Compile a model and train it with an `ActiveSparsification` callback

    Parameters
    ----------
    model : TensorFlow Keras model
        The model to train
    train_x, train_y : array-like
        Training inputs and targets passed to `model.fit`
    loss, metrics, optimizer
        Passed to `model.compile`
    cutoff : float
        `performance_cutoff` for the `ActiveSparsification` callback
    batch_size : int (default 32)
        Batch size passed to `model.fit`
    epochs : int (default 100)
        Maximum number of epochs passed to `model.fit`
    starting_sparsification, max_sparsification, sparsification_rate, sparsification_patience, stopping_patience
        Forwarded to `ActiveSparsification`

    Returns
    -------
    model : TensorFlow Keras model
        The trained model
    """
    sparsifier = ActiveSparsification(
        cutoff,
        starting_sparsification=starting_sparsification,
        max_sparsification=max_sparsification,
        sparsification_rate=sparsification_rate,
        sparsification_patience=sparsification_patience,
        stopping_patience=stopping_patience
    )
    model.compile(loss=loss, metrics=metrics, optimizer=optimizer)
    fit_kwargs = dict(batch_size=batch_size, epochs=epochs, callbacks=[sparsifier])
    model.fit(train_x, train_y, **fit_kwargs)
    return model