Fix FP16 and attention masks in FunnelTransformer (#7374)
* Fix #7371 * Fix training * Fix test values * Apply the fix to TF as well
This commit is contained in:
@@ -367,7 +367,6 @@ class FunnelAttentionStructure(nn.Module):
|
|||||||
# Stride is applied on the second-to-last dimension.
|
# Stride is applied on the second-to-last dimension.
|
||||||
stride = (stride, 1)
|
stride = (stride, 1)
|
||||||
|
|
||||||
tensor = tensor.float()
|
|
||||||
if mode == "mean":
|
if mode == "mean":
|
||||||
tensor = F.avg_pool2d(tensor, stride, stride=stride, ceil_mode=True)
|
tensor = F.avg_pool2d(tensor, stride, stride=stride, ceil_mode=True)
|
||||||
elif mode == "max":
|
elif mode == "max":
|
||||||
@@ -554,7 +553,7 @@ class FunnelRelMultiheadAttention(nn.Module):
|
|||||||
attn_score = attn_score.float()
|
attn_score = attn_score.float()
|
||||||
# perform masking
|
# perform masking
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
attn_score = attn_score - INF * attention_mask[:, None, None].float()
|
attn_score = attn_score - INF * (1 - attention_mask[:, None, None].float())
|
||||||
# attention probability
|
# attention probability
|
||||||
attn_prob = torch.softmax(attn_score, dim=-1, dtype=dtype)
|
attn_prob = torch.softmax(attn_score, dim=-1, dtype=dtype)
|
||||||
attn_prob = self.attention_dropout(attn_prob)
|
attn_prob = self.attention_dropout(attn_prob)
|
||||||
@@ -856,7 +855,9 @@ FUNNEL_INPUTS_DOCSTRING = r"""
|
|||||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
|
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
|
||||||
Mask to avoid performing attention on padding token indices.
|
Mask to avoid performing attention on padding token indices.
|
||||||
Mask values selected in ``[0, 1]``:
|
Mask values selected in ``[0, 1]``:
|
||||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
|
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
|
||||||
|
|||||||
@@ -555,7 +555,7 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer):
|
|||||||
attn_score = tf.cast(attn_score, tf.float32)
|
attn_score = tf.cast(attn_score, tf.float32)
|
||||||
# perform masking
|
# perform masking
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
attn_score = attn_score - INF * tf.cast(attention_mask[:, None, None], tf.float32)
|
attn_score = attn_score - INF * (1 - tf.cast(attention_mask[:, None, None], tf.float32))
|
||||||
# attention probability
|
# attention probability
|
||||||
attn_prob = tf.nn.softmax(attn_score, axis=-1)
|
attn_prob = tf.nn.softmax(attn_score, axis=-1)
|
||||||
if dtype != tf.float32:
|
if dtype != tf.float32:
|
||||||
|
|||||||
@@ -428,16 +428,16 @@ class FunnelModelIntegrationTest(unittest.TestCase):
|
|||||||
model = FunnelModel.from_pretrained("sgugger/funnel-random-tiny")
|
model = FunnelModel.from_pretrained("sgugger/funnel-random-tiny")
|
||||||
output = model(input_ids, token_type_ids=token_type_ids)[0].abs()
|
output = model(input_ids, token_type_ids=token_type_ids)[0].abs()
|
||||||
|
|
||||||
expected_output_sum = torch.tensor(2344.9023)
|
expected_output_sum = torch.tensor(2344.8352)
|
||||||
expected_output_mean = torch.tensor(0.8053)
|
expected_output_mean = torch.tensor(0.8052)
|
||||||
self.assertTrue(torch.allclose(output.sum(), expected_output_sum, atol=1e-4))
|
self.assertTrue(torch.allclose(output.sum(), expected_output_sum, atol=1e-4))
|
||||||
self.assertTrue(torch.allclose(output.mean(), expected_output_mean, atol=1e-4))
|
self.assertTrue(torch.allclose(output.mean(), expected_output_mean, atol=1e-4))
|
||||||
|
|
||||||
attention_mask = torch.tensor([[1] * 7, [1] * 4 + [0] * 3] * 6 + [[0, 1, 1, 0, 0, 1, 1]])
|
attention_mask = torch.tensor([[1] * 7, [1] * 4 + [0] * 3] * 6 + [[0, 1, 1, 0, 0, 1, 1]])
|
||||||
output = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)[0].abs()
|
output = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)[0].abs()
|
||||||
|
|
||||||
expected_output_sum = torch.tensor(2363.2178)
|
expected_output_sum = torch.tensor(2343.8425)
|
||||||
expected_output_mean = torch.tensor(0.8115)
|
expected_output_mean = torch.tensor(0.8049)
|
||||||
self.assertTrue(torch.allclose(output.sum(), expected_output_sum, atol=1e-4))
|
self.assertTrue(torch.allclose(output.sum(), expected_output_sum, atol=1e-4))
|
||||||
self.assertTrue(torch.allclose(output.mean(), expected_output_mean, atol=1e-4))
|
self.assertTrue(torch.allclose(output.mean(), expected_output_mean, atol=1e-4))
|
||||||
|
|
||||||
@@ -448,7 +448,7 @@ class FunnelModelIntegrationTest(unittest.TestCase):
|
|||||||
inputs = tokenizer("Hello! I am the Funnel Transformer model.", return_tensors="pt")
|
inputs = tokenizer("Hello! I am the Funnel Transformer model.", return_tensors="pt")
|
||||||
output = model(**inputs)[0]
|
output = model(**inputs)[0]
|
||||||
|
|
||||||
expected_output_sum = torch.tensor(235.7827)
|
expected_output_sum = torch.tensor(235.7246)
|
||||||
expected_output_mean = torch.tensor(0.0256)
|
expected_output_mean = torch.tensor(0.0256)
|
||||||
self.assertTrue(torch.allclose(output.sum(), expected_output_sum, atol=1e-4))
|
self.assertTrue(torch.allclose(output.sum(), expected_output_sum, atol=1e-4))
|
||||||
self.assertTrue(torch.allclose(output.mean(), expected_output_mean, atol=1e-4))
|
self.assertTrue(torch.allclose(output.mean(), expected_output_mean, atol=1e-4))
|
||||||
|
|||||||
Reference in New Issue
Block a user