fixing SequenceSummary head in TF 2.0
This commit is contained in:
@@ -394,26 +394,26 @@ class TFSequenceSummary(tf.keras.layers.Layer):
|
|||||||
# We can probably just use the multi-head attention module of PyTorch >=1.1.0
|
# We can probably just use the multi-head attention module of PyTorch >=1.1.0
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
self.summary = None
|
self.has_summary = hasattr(config, 'summary_use_proj') and config.summary_use_proj
|
||||||
if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
|
if self.has_summary:
|
||||||
if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
|
if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
|
||||||
num_classes = config.num_labels
|
num_classes = config.num_labels
|
||||||
else:
|
else:
|
||||||
num_classes = config.hidden_size
|
num_classes = config.hidden_size
|
||||||
self.summary = tf.keras.layers.Dense(num_classes,
|
self.summary = tf.keras.layers.Dense(num_classes,
|
||||||
kernel_initializer=get_initializer(initializer_range),
|
kernel_initializer=get_initializer(initializer_range),
|
||||||
name='summary')
|
name='summary')
|
||||||
|
|
||||||
self.activation = None
|
self.has_activation = hasattr(config, 'summary_activation') and config.summary_activation == 'tanh'
|
||||||
if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
|
if self.has_activation:
|
||||||
self.activation = tf.keras.activations.tanh
|
self.activation = tf.keras.activations.tanh
|
||||||
|
|
||||||
self.first_dropout = None
|
self.has_first_dropout = hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0
|
||||||
if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
|
if self.has_first_dropout:
|
||||||
self.first_dropout = tf.keras.layers.Dropout(config.summary_first_dropout)
|
self.first_dropout = tf.keras.layers.Dropout(config.summary_first_dropout)
|
||||||
|
|
||||||
self.last_dropout = None
|
self.has_last_dropout = hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0
|
||||||
if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0:
|
if self.has_last_dropout:
|
||||||
self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)
|
self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, inputs, training=False):
|
||||||
@@ -456,17 +456,17 @@ class TFSequenceSummary(tf.keras.layers.Layer):
|
|||||||
elif self.summary_type == 'attn':
|
elif self.summary_type == 'attn':
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
if training and self.first_dropout is not None:
|
if self.has_first_dropout:
|
||||||
output = self.first_dropout(output)
|
output = self.first_dropout(output, training=training)
|
||||||
|
|
||||||
if self.summary is not None:
|
if self.has_summary:
|
||||||
output = self.summary(output)
|
output = self.summary(output)
|
||||||
|
|
||||||
if self.activation is not None:
|
if self.has_activation:
|
||||||
output = self.activation(output)
|
output = self.activation(output)
|
||||||
|
|
||||||
if training and self.last_dropout is not None:
|
if self.has_last_dropout:
|
||||||
output = self.last_dropout(output)
|
output = self.last_dropout(output, training=training)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user