From c9e8c51946aefffd932068bc801099e7670d8c20 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Thu, 10 Oct 2019 15:16:05 +0200 Subject: [PATCH] fixing SequenceSummary head in TF 2.0 --- transformers/modeling_tf_utils.py | 32 +++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/transformers/modeling_tf_utils.py b/transformers/modeling_tf_utils.py index 3a576345f5..a96e2765fd 100644 --- a/transformers/modeling_tf_utils.py +++ b/transformers/modeling_tf_utils.py @@ -394,26 +394,26 @@ class TFSequenceSummary(tf.keras.layers.Layer): # We can probably just use the multi-head attention module of PyTorch >=1.1.0 raise NotImplementedError - self.summary = None - if hasattr(config, 'summary_use_proj') and config.summary_use_proj: + self.has_summary = hasattr(config, 'summary_use_proj') and config.summary_use_proj + if self.has_summary: if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0: num_classes = config.num_labels else: num_classes = config.hidden_size self.summary = tf.keras.layers.Dense(num_classes, - kernel_initializer=get_initializer(initializer_range), - name='summary') + kernel_initializer=get_initializer(initializer_range), + name='summary') - self.activation = None - if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh': + self.has_activation = hasattr(config, 'summary_activation') and config.summary_activation == 'tanh' + if self.has_activation: self.activation = tf.keras.activations.tanh - self.first_dropout = None - if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0: + self.has_first_dropout = hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0 + if self.has_first_dropout: self.first_dropout = tf.keras.layers.Dropout(config.summary_first_dropout) - self.last_dropout = None - if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0: + self.has_last_dropout = hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0 + if self.has_last_dropout: self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout) def call(self, inputs, training=False): @@ -456,17 +456,17 @@ class TFSequenceSummary(tf.keras.layers.Layer): elif self.summary_type == 'attn': raise NotImplementedError - if training and self.first_dropout is not None: - output = self.first_dropout(output) + if self.has_first_dropout: + output = self.first_dropout(output, training=training) - if self.summary is not None: + if self.has_summary: output = self.summary(output) - if self.activation is not None: + if self.has_activation: output = self.activation(output) - if training and self.last_dropout is not None: - output = self.last_dropout(output) + if self.has_last_dropout: + output = self.last_dropout(output, training=training) return output