Merge pull request #2007 from roskoN/xlnet_attention_fix
fixed XLNet attention output for both attention streams whenever target_mapping is provided
This commit is contained in:
@@ -583,6 +583,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
|||||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||||
|
When ``target_mapping is not None``, the attentions outputs are a list of 2-tuple of ``torch.FloatTensor``.
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
@@ -878,6 +879,10 @@ class XLNetModel(XLNetPreTrainedModel):
|
|||||||
hidden_states = tuple(hs.permute(1, 0, 2).contiguous() for hs in hidden_states)
|
hidden_states = tuple(hs.permute(1, 0, 2).contiguous() for hs in hidden_states)
|
||||||
outputs = outputs + (hidden_states,)
|
outputs = outputs + (hidden_states,)
|
||||||
if self.output_attentions:
|
if self.output_attentions:
|
||||||
|
if target_mapping is not None:
|
||||||
|
# when target_mapping is provided, there are 2-tuple of attentions
|
||||||
|
attentions = tuple(tuple(att_stream.permute(2, 3, 0, 1).contiguous() for att_stream in t) for t in attentions)
|
||||||
|
else:
|
||||||
attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
|
attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
|
||||||
outputs = outputs + (attentions,)
|
outputs = outputs + (attentions,)
|
||||||
|
|
||||||
@@ -913,6 +918,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
|||||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||||
|
When ``target_mapping is not None``, the attentions outputs are a list of 2-tuple of ``torch.FloatTensor``.
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
@@ -995,6 +1001,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
|||||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||||
|
When ``target_mapping is not None``, the attentions outputs are a list of 2-tuple of ``torch.FloatTensor``.
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
@@ -1195,6 +1202,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
|
|||||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||||
|
When ``target_mapping is not None``, the attentions outputs are a list of 2-tuple of ``torch.FloatTensor``.
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
@@ -1280,6 +1288,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
|
|||||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||||
|
When ``target_mapping is not None``, the attentions outputs are a list of 2-tuple of ``torch.FloatTensor``.
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
@@ -1394,6 +1403,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
|||||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||||
|
When ``target_mapping is not None``, the attentions outputs are a list of 2-tuple of ``torch.FloatTensor``.
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
|
|||||||
@@ -166,6 +166,18 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
|||||||
list(list(mem.size()) for mem in result["mems_1"]),
|
list(list(mem.size()) for mem in result["mems_1"]),
|
||||||
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
[[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers)
|
||||||
|
|
||||||
|
def create_and_check_xlnet_base_model_with_att_output(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||||
|
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
|
||||||
|
model = XLNetModel(config)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
_, _, attentions = model(input_ids_1, target_mapping=target_mapping)
|
||||||
|
|
||||||
|
self.parent.assertEqual(len(attentions), config.n_layer)
|
||||||
|
self.parent.assertIsInstance(attentions[0], tuple)
|
||||||
|
self.parent.assertEqual(len(attentions[0]), 2)
|
||||||
|
self.parent.assertTrue(attentions[0][0].shape, attentions[0][0].shape)
|
||||||
|
|
||||||
def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask,
|
||||||
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
|
target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels, token_labels):
|
||||||
model = XLNetLMHeadModel(config)
|
model = XLNetLMHeadModel(config)
|
||||||
@@ -341,6 +353,12 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_xlnet_base_model(*config_and_inputs)
|
self.model_tester.create_and_check_xlnet_base_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_xlnet_base_model_with_att_output(self):
|
||||||
|
self.model_tester.set_seed()
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
config_and_inputs[0].output_attentions = True
|
||||||
|
self.model_tester.create_and_check_xlnet_base_model_with_att_output(*config_and_inputs)
|
||||||
|
|
||||||
def test_xlnet_lm_head(self):
|
def test_xlnet_lm_head(self):
|
||||||
self.model_tester.set_seed()
|
self.model_tester.set_seed()
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
|||||||
Reference in New Issue
Block a user