diff --git a/src/transformers/models/mpnet/modeling_tf_mpnet.py b/src/transformers/models/mpnet/modeling_tf_mpnet.py index 899fd09c67..c3b95c6b77 100644 --- a/src/transformers/models/mpnet/modeling_tf_mpnet.py +++ b/src/transformers/models/mpnet/modeling_tf_mpnet.py @@ -348,15 +348,22 @@ class TFMPNetEncoder(tf.keras.layers.Layer): self.n_heads = config.num_attention_heads self.output_attentions = config.output_attentions self.output_hidden_states = config.output_hidden_states + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.initializer_range = config.initializer_range self.layer = [TFMPNetLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)] - self.relative_attention_bias = tf.keras.layers.Embedding( - config.relative_attention_num_buckets, - self.n_heads, - name="relative_attention_bias", - ) self.relative_attention_num_buckets = config.relative_attention_num_buckets + def build(self, input_shape): + with tf.name_scope("relative_attention_bias"): + self.relative_attention_bias = self.add_weight( + name="embeddings", + shape=[self.relative_attention_num_buckets, self.n_heads], + initializer=get_initializer(self.initializer_range), + ) + + return super().build(input_shape) + def call( self, hidden_states, @@ -405,18 +412,16 @@ class TFMPNetEncoder(tf.keras.layers.Layer): n = -relative_position num_buckets //= 2 - ret += tf.dtypes.cast(tf.math.less(n, 0), tf.int32) * num_buckets + ret += tf.cast(tf.math.less(n, 0), dtype=relative_position.dtype) * num_buckets n = tf.math.abs(n) # now n is in the range [0, inf) max_exact = num_buckets // 2 is_small = tf.math.less(n, max_exact) - val_if_large = max_exact + tf.dtypes.cast( - tf.math.log(tf.dtypes.cast(n, tf.float32) / max_exact) - / math.log(max_distance / max_exact) - * (num_buckets - max_exact), - tf.int32, + val_if_large = max_exact + tf.cast( + tf.math.log(n / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact), + dtype=relative_position.dtype, ) val_if_large = tf.math.minimum(val_if_large, num_buckets - 1) @@ -441,7 +446,7 @@ class TFMPNetEncoder(tf.keras.layers.Layer): relative_position, num_buckets=self.relative_attention_num_buckets, ) - values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads) + values = tf.gather(self.relative_attention_bias, rp_bucket) # shape (qlen, klen, num_heads) values = tf.expand_dims(tf.transpose(values, [2, 0, 1]), axis=0) # shape (1, num_heads, qlen, klen) return values @@ -541,7 +546,9 @@ class TFMPNetMainLayer(tf.keras.layers.Layer): # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. extended_attention_mask = tf.cast(extended_attention_mask, embedding_output.dtype) - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + one_cst = tf.constant(1.0, dtype=embedding_output.dtype) + ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) + extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head diff --git a/tests/test_modeling_tf_mpnet.py b/tests/test_modeling_tf_mpnet.py index d67d68f5d3..c0305dede9 100644 --- a/tests/test_modeling_tf_mpnet.py +++ b/tests/test_modeling_tf_mpnet.py @@ -232,10 +232,6 @@ class TFMPNetModelTest(TFModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_mpnet_for_token_classification(*config_and_inputs) - def test_xla_mode(self): - # TODO JP: Make MPNet XLA compliant - pass - @slow def test_model_from_pretrained(self): for model_name in ["microsoft/mpnet-base"]: