From 76c0bc06d549e0ecf746fdd3d72eb235a9c4aec2 Mon Sep 17 00:00:00 2001 From: Rostislav Nedelchev Date: Sat, 30 Nov 2019 21:01:04 +0100 Subject: [PATCH] [XLNet] Changed post-processing of attention w.r.t. target_mapping Whenever target_mapping is provided to the input, XLNet outputs two different attention streams. Based on that, the attention output is one of the two: - a list of tensors (usual case for most transformers) - a list of 2-tuples of tensors, one tensor for each of the attention streams Docs and unit-tests have been updated --- transformers/modeling_xlnet.py | 24 ++++++++++++++++------- transformers/tests/modeling_xlnet_test.py | 18 +++++++++++++++++ 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/transformers/modeling_xlnet.py b/transformers/modeling_xlnet.py index 476d9ab13d..56d755c11b 100644 --- a/transformers/modeling_xlnet.py +++ b/transformers/modeling_xlnet.py @@ -581,8 +581,9 @@ class XLNetModel(XLNetPreTrainedModel): of shape ``(batch_size, sequence_length, hidden_size)``: Hidden-states of the model at the output of each layer plus the initial embedding outputs. **attentions**: (`optional`, returned when ``config.output_attentions=True``) - list of 2-tuple of ``torch.FloatTensor`` (one for each layer, one for each attention stream) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + When ``target_mapping is not None``, the attentions outputs are a list of 2-tuple of ``torch.FloatTensor``. 
Examples:: @@ -878,7 +879,11 @@ class XLNetModel(XLNetPreTrainedModel): hidden_states = tuple(hs.permute(1, 0, 2).contiguous() for hs in hidden_states) outputs = outputs + (hidden_states,) if self.output_attentions: - attentions = tuple(tuple(att_stream.permute(2, 3, 0, 1).contiguous() for att_stream in t) for t in attentions) + if target_mapping is not None: + # when target_mapping is provided, there are 2-tuple of attentions + attentions = tuple(tuple(att_stream.permute(2, 3, 0, 1).contiguous() for att_stream in t) for t in attentions) + else: + attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions) outputs = outputs + (attentions,) return outputs # outputs, (new_mems), (hidden_states), (attentions) @@ -911,8 +916,9 @@ class XLNetLMHeadModel(XLNetPreTrainedModel): of shape ``(batch_size, sequence_length, hidden_size)``: Hidden-states of the model at the output of each layer plus the initial embedding outputs. **attentions**: (`optional`, returned when ``config.output_attentions=True``) - list of 2-tuple of ``torch.FloatTensor`` (one for each layer, one for each attention stream) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + When ``target_mapping is not None``, the attentions outputs are a list of 2-tuple of ``torch.FloatTensor``. Examples:: @@ -993,8 +999,9 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel): of shape ``(batch_size, sequence_length, hidden_size)``: Hidden-states of the model at the output of each layer plus the initial embedding outputs. 
**attentions**: (`optional`, returned when ``config.output_attentions=True``) - list of 2-tuple of ``torch.FloatTensor`` (one for each layer, one for each attention stream) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + When ``target_mapping is not None``, the attentions outputs are a list of 2-tuple of ``torch.FloatTensor``. Examples:: @@ -1093,8 +1100,9 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel): of shape ``(batch_size, sequence_length, hidden_size)``: Hidden-states of the model at the output of each layer plus the initial embedding outputs. **attentions**: (`optional`, returned when ``config.output_attentions=True``) - list of 2-tuple of ``torch.FloatTensor`` (one for each layer, one for each attention stream) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + When ``target_mapping is not None``, the attentions outputs are a list of 2-tuple of ``torch.FloatTensor``. Examples:: @@ -1178,8 +1186,9 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel): of shape ``(batch_size, sequence_length, hidden_size)``: Hidden-states of the model at the output of each layer plus the initial embedding outputs. 
**attentions**: (`optional`, returned when ``config.output_attentions=True``) - list of 2-tuple of ``torch.FloatTensor`` (one for each layer, one for each attention stream) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + When ``target_mapping is not None``, the attentions outputs are a list of 2-tuple of ``torch.FloatTensor``. Examples:: @@ -1292,8 +1301,9 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel): of shape ``(batch_size, sequence_length, hidden_size)``: Hidden-states of the model at the output of each layer plus the initial embedding outputs. **attentions**: (`optional`, returned when ``config.output_attentions=True``) - list of 2-tuple of ``torch.FloatTensor`` (one for each layer, one for each attention stream) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + When ``target_mapping is not None``, the attentions outputs are a list of 2-tuple of ``torch.FloatTensor``. 
Examples:: diff --git a/transformers/tests/modeling_xlnet_test.py b/transformers/tests/modeling_xlnet_test.py index d97ea6a425..a5ee9b1e0e 100644 --- a/transformers/tests/modeling_xlnet_test.py +++ b/transformers/tests/modeling_xlnet_test.py @@ -163,6 +163,18 @@ class XLNetModelTest(CommonTestCases.CommonModelTester): list(list(mem.size()) for mem in result["mems_1"]), [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers) + def create_and_check_xlnet_base_model_with_att_output(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, + target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels): + model = XLNetModel(config) + model.eval() + + _, _, attentions = model(input_ids_1, target_mapping=target_mapping) + + self.parent.assertEqual(len(attentions), config.n_layer) + self.parent.assertIsInstance(attentions[0], tuple) + self.parent.assertEqual(len(attentions[0]), 2) + self.parent.assertEqual(attentions[0][0].shape, attentions[0][1].shape)  # both attention streams must share one shape + def create_and_check_xlnet_lm_head(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels): model = XLNetLMHeadModel(config) @@ -306,6 +318,12 @@ class XLNetModelTest(CommonTestCases.CommonModelTester): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_xlnet_base_model(*config_and_inputs) + def test_xlnet_base_model_with_att_output(self): + self.model_tester.set_seed() + config_and_inputs = self.model_tester.prepare_config_and_inputs() + config_and_inputs[0].output_attentions = True + self.model_tester.create_and_check_xlnet_base_model_with_att_output(*config_and_inputs) + def test_xlnet_lm_head(self): self.model_tester.set_seed() config_and_inputs = self.model_tester.prepare_config_and_inputs()