Generate can return cross-attention weights too (#10493)

This commit is contained in:
Mehrad Moradshahi
2021-03-03 00:27:02 -08:00
committed by GitHub
parent b013842244
commit 1750e62900
2 changed files with 45 additions and 7 deletions

View File

@@ -96,6 +96,9 @@ class GreedySearchEncoderDecoderOutput(ModelOutput):
decoder_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): decoder_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`. :obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, generated_length, hidden_size)`. :obj:`torch.FloatTensor` of shape :obj:`(batch_size, generated_length, hidden_size)`.
@@ -106,6 +109,7 @@ class GreedySearchEncoderDecoderOutput(ModelOutput):
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@@ -164,6 +168,9 @@ class SampleEncoderDecoderOutput(ModelOutput):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, num_heads, generated_length, :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, num_heads, generated_length,
sequence_length)`. sequence_length)`.
cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, generated_length, hidden_size)`. :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, generated_length, hidden_size)`.
@@ -174,6 +181,7 @@ class SampleEncoderDecoderOutput(ModelOutput):
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@@ -239,6 +247,9 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, num_heads, :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, num_heads,
generated_length, sequence_length)`. generated_length, sequence_length)`.
cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, generated_length, :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, generated_length,
@@ -251,6 +262,7 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@@ -314,6 +326,9 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length, :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length,
sequence_length)`. sequence_length)`.
cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, generated_length, hidden_size)`. :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, generated_length, hidden_size)`.
@@ -325,6 +340,7 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@@ -1177,6 +1193,7 @@ class GenerationMixin:
# init attention / hidden states / scores tuples # init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
@@ -1212,6 +1229,8 @@ class GenerationMixin:
decoder_attentions += ( decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
) )
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states: if output_hidden_states:
decoder_hidden_states += ( decoder_hidden_states += (
@@ -1260,6 +1279,7 @@ class GenerationMixin:
encoder_attentions=encoder_attentions, encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions, decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states, decoder_hidden_states=decoder_hidden_states,
) )
else: else:
@@ -1384,6 +1404,7 @@ class GenerationMixin:
# init attention / hidden states / scores tuples # init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
@@ -1424,6 +1445,8 @@ class GenerationMixin:
decoder_attentions += ( decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
) )
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states: if output_hidden_states:
decoder_hidden_states += ( decoder_hidden_states += (
@@ -1468,6 +1491,7 @@ class GenerationMixin:
encoder_attentions=encoder_attentions, encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions, decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states, decoder_hidden_states=decoder_hidden_states,
) )
else: else:
@@ -1604,6 +1628,7 @@ class GenerationMixin:
# init attention / hidden states / scores tuples # init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
@@ -1656,6 +1681,8 @@ class GenerationMixin:
decoder_attentions += ( decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
) )
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states: if output_hidden_states:
decoder_hidden_states += ( decoder_hidden_states += (
@@ -1716,6 +1743,7 @@ class GenerationMixin:
encoder_attentions=encoder_attentions, encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions, decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states, decoder_hidden_states=decoder_hidden_states,
) )
else: else:
@@ -1865,6 +1893,7 @@ class GenerationMixin:
# init attention / hidden states / scores tuples # init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
@@ -1913,6 +1942,8 @@ class GenerationMixin:
decoder_attentions += ( decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
) )
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states: if output_hidden_states:
decoder_hidden_states += ( decoder_hidden_states += (
@@ -1968,17 +1999,18 @@ class GenerationMixin:
if not output_scores: if not output_scores:
sequence_outputs["sequence_scores"] = None sequence_outputs["sequence_scores"] = None
if self.config.is_encoder_decoder: if self.config.is_encoder_decoder:
return BeamSearchEncoderDecoderOutput( return BeamSampleEncoderDecoderOutput(
sequences=sequence_outputs["sequences"], sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"], sequences_scores=sequence_outputs["sequence_scores"],
scores=scores, scores=scores,
encoder_attentions=encoder_attentions, encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions, decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states, decoder_hidden_states=decoder_hidden_states,
) )
else: else:
return BeamSearchDecoderOnlyOutput( return BeamSampleDecoderOnlyOutput(
sequences=sequence_outputs["sequences"], sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"], sequences_scores=sequence_outputs["sequence_scores"],
scores=scores, scores=scores,
@@ -2115,6 +2147,7 @@ class GenerationMixin:
# init attention / hidden states / scores tuples # init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
@@ -2238,6 +2271,8 @@ class GenerationMixin:
decoder_attentions += ( decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
) )
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states: if output_hidden_states:
decoder_hidden_states += ( decoder_hidden_states += (
@@ -2263,7 +2298,7 @@ class GenerationMixin:
if return_dict_in_generate: if return_dict_in_generate:
if not output_scores: if not output_scores:
sequence_outputs["sequence_scores"] sequence_outputs["sequence_scores"] = None
if self.config.is_encoder_decoder: if self.config.is_encoder_decoder:
return BeamSearchEncoderDecoderOutput( return BeamSearchEncoderDecoderOutput(
sequences=sequence_outputs["sequences"], sequences=sequence_outputs["sequences"],
@@ -2272,6 +2307,7 @@ class GenerationMixin:
encoder_attentions=encoder_attentions, encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions, decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states, decoder_hidden_states=decoder_hidden_states,
) )
else: else:

View File

@@ -39,6 +39,8 @@ if is_torch_available():
TopPLogitsWarper, TopPLogitsWarper,
) )
from transformers.generation_utils import ( from transformers.generation_utils import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
BeamSearchDecoderOnlyOutput, BeamSearchDecoderOnlyOutput,
BeamSearchEncoderDecoderOutput, BeamSearchEncoderDecoderOutput,
GreedySearchDecoderOnlyOutput, GreedySearchDecoderOnlyOutput,
@@ -900,11 +902,11 @@ class GenerationTesterMixin:
) )
if model.config.is_encoder_decoder: if model.config.is_encoder_decoder:
self.assertIsInstance(output_beam_sample, BeamSearchEncoderDecoderOutput) self.assertIsInstance(output_beam_sample, BeamSampleEncoderDecoderOutput)
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
else: else:
self.assertIsInstance(output_beam_sample, BeamSearchDecoderOnlyOutput) self.assertIsInstance(output_beam_sample, BeamSampleDecoderOnlyOutput)
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
self.assertListEqual(output_generate.sequences.tolist(), output_beam_sample.sequences.tolist()) self.assertListEqual(output_generate.sequences.tolist(), output_beam_sample.sequences.tolist())
self.assertTrue( self.assertTrue(