From ad39271ae8518afaa4c4831bd1ddccc3bc86eab6 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Fri, 25 Sep 2020 12:20:39 -0400 Subject: [PATCH] Fix FP16 and attention masks in FunnelTransformer (#7374) * Fix #7371 * Fix training * Fix test values * Apply the fix to TF as well --- src/transformers/modeling_funnel.py | 7 ++++--- src/transformers/modeling_tf_funnel.py | 2 +- tests/test_modeling_funnel.py | 10 +++++----- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/transformers/modeling_funnel.py b/src/transformers/modeling_funnel.py index 8c8810d507..0ef5d596ea 100644 --- a/src/transformers/modeling_funnel.py +++ b/src/transformers/modeling_funnel.py @@ -367,7 +367,6 @@ class FunnelAttentionStructure(nn.Module): # Stride is applied on the second-to-last dimension. stride = (stride, 1) - tensor = tensor.float() if mode == "mean": tensor = F.avg_pool2d(tensor, stride, stride=stride, ceil_mode=True) elif mode == "max": @@ -554,7 +553,7 @@ class FunnelRelMultiheadAttention(nn.Module): attn_score = attn_score.float() # perform masking if attention_mask is not None: - attn_score = attn_score - INF * attention_mask[:, None, None].float() + attn_score = attn_score - INF * (1 - attention_mask[:, None, None].float()) # attention probability attn_prob = torch.softmax(attn_score, dim=-1, dtype=dtype) attn_prob = self.attention_dropout(attn_prob) @@ -856,7 +855,9 @@ FUNNEL_INPUTS_DOCSTRING = r""" attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`): Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: - ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. `What are attention masks? 
<../glossary.html#attention-mask>`__ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`): diff --git a/src/transformers/modeling_tf_funnel.py b/src/transformers/modeling_tf_funnel.py index 3d35d2c4af..2137b00183 100644 --- a/src/transformers/modeling_tf_funnel.py +++ b/src/transformers/modeling_tf_funnel.py @@ -555,7 +555,7 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer): attn_score = tf.cast(attn_score, tf.float32) # perform masking if attention_mask is not None: - attn_score = attn_score - INF * tf.cast(attention_mask[:, None, None], tf.float32) + attn_score = attn_score - INF * (1 - tf.cast(attention_mask[:, None, None], tf.float32)) # attention probability attn_prob = tf.nn.softmax(attn_score, axis=-1) if dtype != tf.float32: diff --git a/tests/test_modeling_funnel.py b/tests/test_modeling_funnel.py index 0197a8fec0..6d12647857 100644 --- a/tests/test_modeling_funnel.py +++ b/tests/test_modeling_funnel.py @@ -428,16 +428,16 @@ class FunnelModelIntegrationTest(unittest.TestCase): model = FunnelModel.from_pretrained("sgugger/funnel-random-tiny") output = model(input_ids, token_type_ids=token_type_ids)[0].abs() - expected_output_sum = torch.tensor(2344.9023) - expected_output_mean = torch.tensor(0.8053) + expected_output_sum = torch.tensor(2344.8352) + expected_output_mean = torch.tensor(0.8052) self.assertTrue(torch.allclose(output.sum(), expected_output_sum, atol=1e-4)) self.assertTrue(torch.allclose(output.mean(), expected_output_mean, atol=1e-4)) attention_mask = torch.tensor([[1] * 7, [1] * 4 + [0] * 3] * 6 + [[0, 1, 1, 0, 0, 1, 1]]) output = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)[0].abs() - expected_output_sum = torch.tensor(2363.2178) - expected_output_mean = torch.tensor(0.8115) + expected_output_sum = torch.tensor(2343.8425) + expected_output_mean = torch.tensor(0.8049) self.assertTrue(torch.allclose(output.sum(), expected_output_sum, atol=1e-4)) 
self.assertTrue(torch.allclose(output.mean(), expected_output_mean, atol=1e-4)) @@ -448,7 +448,7 @@ class FunnelModelIntegrationTest(unittest.TestCase): inputs = tokenizer("Hello! I am the Funnel Transformer model.", return_tensors="pt") output = model(**inputs)[0] - expected_output_sum = torch.tensor(235.7827) + expected_output_sum = torch.tensor(235.7246) expected_output_mean = torch.tensor(0.0256) self.assertTrue(torch.allclose(output.sum(), expected_output_sum, atol=1e-4)) self.assertTrue(torch.allclose(output.mean(), expected_output_mean, atol=1e-4))