From 1750e629006bb6989aef5b4e141f3477f891a098 Mon Sep 17 00:00:00 2001 From: Mehrad Moradshahi Date: Wed, 3 Mar 2021 00:27:02 -0800 Subject: [PATCH] Generate can return cross-attention weights too (#10493) --- src/transformers/generation_utils.py | 42 ++++++++++++++++++++++++++-- tests/test_generation_utils.py | 10 ++++--- 2 files changed, 45 insertions(+), 7 deletions(-) diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py index 1a75718b1d..3a2d56d87c 100644 --- a/src/transformers/generation_utils.py +++ b/src/transformers/generation_utils.py @@ -96,6 +96,9 @@ class GreedySearchEncoderDecoderOutput(ModelOutput): decoder_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of :obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`. + cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + :obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`. decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of :obj:`torch.FloatTensor` of shape :obj:`(batch_size, generated_length, hidden_size)`. 
@@ -106,6 +109,7 @@ class GreedySearchEncoderDecoderOutput(ModelOutput): encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None @@ -164,6 +168,9 @@ class SampleEncoderDecoderOutput(ModelOutput): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, num_heads, generated_length, sequence_length)`. + cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, num_heads, generated_length, sequence_length)`. decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, generated_length, hidden_size)`. 
@@ -174,6 +181,7 @@ class SampleEncoderDecoderOutput(ModelOutput): encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None @@ -239,6 +247,9 @@ class BeamSearchEncoderDecoderOutput(ModelOutput): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, num_heads, generated_length, sequence_length)`. + cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, num_heads, generated_length, sequence_length)`. 
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, generated_length, @@ -251,6 +262,7 @@ class BeamSearchEncoderDecoderOutput(ModelOutput): encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None @@ -314,6 +326,9 @@ class BeamSampleEncoderDecoderOutput(ModelOutput): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length, sequence_length)`. + cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length, sequence_length)`. decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, generated_length, hidden_size)`. 
@@ -325,6 +340,7 @@ class BeamSampleEncoderDecoderOutput(ModelOutput): encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None @@ -1177,6 +1193,7 @@ class GenerationMixin: # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states @@ -1212,6 +1229,8 @@ class GenerationMixin: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( @@ -1260,6 +1279,7 @@ class GenerationMixin: encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, ) else: @@ -1384,6 +1404,7 @@ class GenerationMixin: # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states @@ -1424,6 
+1445,8 @@ class GenerationMixin: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( @@ -1468,6 +1491,7 @@ class GenerationMixin: encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, ) else: @@ -1604,6 +1628,7 @@ class GenerationMixin: # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states @@ -1656,6 +1681,8 @@ class GenerationMixin: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( @@ -1716,6 +1743,7 @@ class GenerationMixin: encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, ) else: @@ -1865,6 +1893,7 @@ class GenerationMixin: # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if 
(return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states @@ -1913,6 +1942,8 @@ class GenerationMixin: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( @@ -1968,17 +1999,18 @@ class GenerationMixin: if not output_scores: sequence_outputs["sequence_scores"] = None if self.config.is_encoder_decoder: - return BeamSearchEncoderDecoderOutput( + return BeamSampleEncoderDecoderOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, ) else: - return BeamSearchDecoderOnlyOutput( + return BeamSampleDecoderOnlyOutput( sequences=sequence_outputs["sequences"], sequences_scores=sequence_outputs["sequence_scores"], scores=scores, @@ -2115,6 +2147,7 @@ class GenerationMixin: # init attention / hidden states / scores tuples scores = () if (return_dict_in_generate and output_scores) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None # if model is an encoder-decoder, retrieve encoder attention weights and hidden states @@ -2238,6 +2271,8 @@ class GenerationMixin: decoder_attentions += ( (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) if output_hidden_states: decoder_hidden_states += ( 
@@ -2263,7 +2298,7 @@ class GenerationMixin: if return_dict_in_generate: if not output_scores: - sequence_outputs["sequence_scores"] + sequence_outputs["sequence_scores"] = None if self.config.is_encoder_decoder: return BeamSearchEncoderDecoderOutput( sequences=sequence_outputs["sequences"], @@ -2272,6 +2307,7 @@ class GenerationMixin: encoder_attentions=encoder_attentions, encoder_hidden_states=encoder_hidden_states, decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, ) else: diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py index d1f01a7ae1..2c96693069 100644 --- a/tests/test_generation_utils.py +++ b/tests/test_generation_utils.py @@ -39,6 +39,8 @@ if is_torch_available(): TopPLogitsWarper, ) from transformers.generation_utils import ( + BeamSampleDecoderOnlyOutput, + BeamSampleEncoderDecoderOutput, BeamSearchDecoderOnlyOutput, BeamSearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput, @@ -900,11 +902,11 @@ class GenerationTesterMixin: ) if model.config.is_encoder_decoder: - self.assertIsInstance(output_beam_sample, BeamSearchEncoderDecoderOutput) - self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) + self.assertIsInstance(output_beam_sample, BeamSampleEncoderDecoderOutput) + self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput) else: - self.assertIsInstance(output_beam_sample, BeamSearchDecoderOnlyOutput) - self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) + self.assertIsInstance(output_beam_sample, BeamSampleDecoderOnlyOutput) + self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput) self.assertListEqual(output_generate.sequences.tolist(), output_beam_sample.sequences.tolist()) self.assertTrue(