from tensorflow.keras.layers import Layer
import tensorflow as tf
class SparseMultiDense(Layer):
    """
    Sparse implementation of the MultiDense layer. If used in a model, must be saved and loaded via pickle

    The kernel and bias are split along their first axis and every slice is
    stored as a sparse tensor; the slices are converted back to dense tensors
    each time the layer is called.

    Notes
    -----
    ``get_config`` does not record ``weight`` or ``bias``, so a config produced
    by it lacks the two required constructor arguments.
    """

    def __init__(self, weight, bias, activation=None, **kwargs):
        """
        Parameters
        ----------
        weight : tf.Tensor
            The kernel tensor. It is sliced along its first axis and each
            slice is stored in sparse form.
        bias : tf.Tensor
            The bias tensor. It is sliced along its first axis and each
            slice is stored in sparse form.
        activation : None, str or keras activation function (default None)
            The activation function to use
        **kwargs
            Forwarded unchanged to the base ``Layer`` constructor.
        """
        super().__init__(**kwargs)
        # Kernel slices are built before bias slices, one sparse tensor per
        # index along the first axis.
        self.w = self._sparse_slices(weight)
        self.b = self._sparse_slices(bias)
        self.activation = tf.keras.activations.get(activation)

    @staticmethod
    def _sparse_slices(dense):
        """
        Map each index along the first axis of `dense` to the sparse form of
        the slice found at that index.
        """
        slices = {}
        for index in range(dense.shape[0]):
            slices[index] = tf.sparse.from_dense(dense[index])
        return slices

    def build(self, input_shape):
        """
        Build the layer in preparation to be trained or called. Should not be called directly,
        but rather is called when the layer is added to a model

        Nothing is created here: the kernel and bias slices are already set
        up by the constructor.
        """

    def call(self, inputs):
        """
        Apply every stored kernel/bias slice to the matching entry of `inputs`.

        Parameters
        ----------
        inputs : TensorFlow Tensor or Tensor-like
            The inputs to the layer. Must support ``len()`` and integer
            indexing; entry ``i`` is multiplied by kernel slice ``i``.

        Returns
        -------
        outputs : list of TensorFlow Tensor
            One activated output per entry of `inputs`, in the same order.
        """
        pre_activation = []
        for idx in range(len(inputs)):
            # The sparse slices are densified on every call before use.
            projected = tf.matmul(inputs[idx], tf.sparse.to_dense(self.w[idx]))
            pre_activation.append(projected + tf.sparse.to_dense(self.b[idx]))
        # Activations are applied only once all linear outputs exist.
        return list(map(self.activation, pre_activation))

    def get_config(self):
        """
        Return the base layer config extended with the serialized activation.

        The kernel and bias are not part of the returned dictionary.
        """
        return {
            **super().get_config(),
            'activation': tf.keras.activations.serialize(self.activation),
        }

    @classmethod
    def from_layer(cls, layer):
        """
        Create a layer from an instance of another layer

        The first two arrays returned by ``layer.get_weights()`` are used as
        the kernel and the bias, and ``layer.activation`` as the activation.
        """
        params = layer.get_weights()
        return cls(params[0], params[1], layer.activation)

    @classmethod
    def from_config(cls, config):
        """
        Instantiate the layer by passing `config` as keyword arguments to the
        constructor.
        """
        return cls(**config)