Generate can return cross-attention weights too (#10493)

This commit is contained in:
Mehrad Moradshahi
2021-03-03 00:27:02 -08:00
committed by GitHub
parent b013842244
commit 1750e62900
2 changed files with 45 additions and 7 deletions

View File

@@ -96,6 +96,9 @@ class GreedySearchEncoderDecoderOutput(ModelOutput):
decoder_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, generated_length, hidden_size)`.
@@ -106,6 +109,7 @@ class GreedySearchEncoderDecoderOutput(ModelOutput):
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@@ -164,6 +168,9 @@ class SampleEncoderDecoderOutput(ModelOutput):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, num_heads, generated_length,
sequence_length)`.
cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, generated_length, hidden_size)`.
@@ -174,6 +181,7 @@ class SampleEncoderDecoderOutput(ModelOutput):
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@@ -239,6 +247,9 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, num_heads,
generated_length, sequence_length)`.
cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, generated_length,
@@ -251,6 +262,7 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@@ -314,6 +326,9 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length,
sequence_length)`.
cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, generated_length, hidden_size)`.
@@ -325,6 +340,7 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@@ -1177,6 +1193,7 @@ class GenerationMixin:
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
@@ -1212,6 +1229,8 @@ class GenerationMixin:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states:
decoder_hidden_states += (
@@ -1260,6 +1279,7 @@ class GenerationMixin:
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
)
else:
@@ -1384,6 +1404,7 @@ class GenerationMixin:
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
@@ -1424,6 +1445,8 @@ class GenerationMixin:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states:
decoder_hidden_states += (
@@ -1468,6 +1491,7 @@ class GenerationMixin:
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
)
else:
@@ -1604,6 +1628,7 @@ class GenerationMixin:
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
@@ -1656,6 +1681,8 @@ class GenerationMixin:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states:
decoder_hidden_states += (
@@ -1716,6 +1743,7 @@ class GenerationMixin:
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
)
else:
@@ -1865,6 +1893,7 @@ class GenerationMixin:
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
@@ -1913,6 +1942,8 @@ class GenerationMixin:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states:
decoder_hidden_states += (
@@ -1968,17 +1999,18 @@ class GenerationMixin:
if not output_scores:
sequence_outputs["sequence_scores"] = None
if self.config.is_encoder_decoder:
return BeamSearchEncoderDecoderOutput(
return BeamSampleEncoderDecoderOutput(
sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
)
else:
return BeamSearchDecoderOnlyOutput(
return BeamSampleDecoderOnlyOutput(
sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
@@ -2115,6 +2147,7 @@ class GenerationMixin:
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
@@ -2238,6 +2271,8 @@ class GenerationMixin:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states:
decoder_hidden_states += (
@@ -2263,7 +2298,7 @@ class GenerationMixin:
if return_dict_in_generate:
if not output_scores:
sequence_outputs["sequence_scores"]
sequence_outputs["sequence_scores"] = None
if self.config.is_encoder_decoder:
return BeamSearchEncoderDecoderOutput(
sequences=sequence_outputs["sequences"],
@@ -2272,6 +2307,7 @@ class GenerationMixin:
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
)
else: