Allow RAG to output decoder cross-attentions (#9789)

* get cross attns

* add cross-attns doc strings

* fix typo

* line length

* Apply suggestions from code review

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
This commit is contained in:
Derrick Blakely
2021-01-26 12:32:46 -05:00
committed by GitHub
parent 8f6c12d306
commit 8edc98bb70

View File

@@ -102,6 +102,12 @@ class RetrievAugLMMarginOutput(ModelOutput):
Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
average in the self-attention heads. average in the self-attention heads.
generator_cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Cross-attentions weights of the generator decoder, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
""" """
loss: Optional[torch.FloatTensor] = None loss: Optional[torch.FloatTensor] = None
@@ -120,6 +126,7 @@ class RetrievAugLMMarginOutput(ModelOutput):
generator_enc_attentions: Optional[Tuple[torch.FloatTensor]] = None generator_enc_attentions: Optional[Tuple[torch.FloatTensor]] = None
generator_dec_hidden_states: Optional[Tuple[torch.FloatTensor]] = None generator_dec_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
generator_dec_attentions: Optional[Tuple[torch.FloatTensor]] = None generator_dec_attentions: Optional[Tuple[torch.FloatTensor]] = None
generator_cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass @dataclass
@@ -186,6 +193,12 @@ class RetrievAugLMOutput(ModelOutput):
Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
average in the self-attention heads. average in the self-attention heads.
generator_cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Cross-attentions weights of the generator decoder, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
""" """
logits: torch.FloatTensor = None logits: torch.FloatTensor = None
@@ -203,6 +216,7 @@ class RetrievAugLMOutput(ModelOutput):
generator_enc_attentions: Optional[Tuple[torch.FloatTensor]] = None generator_enc_attentions: Optional[Tuple[torch.FloatTensor]] = None
generator_dec_hidden_states: Optional[Tuple[torch.FloatTensor]] = None generator_dec_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
generator_dec_attentions: Optional[Tuple[torch.FloatTensor]] = None generator_dec_attentions: Optional[Tuple[torch.FloatTensor]] = None
generator_cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
class RagPreTrainedModel(PreTrainedModel): class RagPreTrainedModel(PreTrainedModel):
@@ -619,6 +633,7 @@ class RagModel(RagPreTrainedModel):
decoder_attention_mask=decoder_attention_mask, decoder_attention_mask=decoder_attention_mask,
past_key_values=past_key_values, past_key_values=past_key_values,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions,
return_dict=True, return_dict=True,
) )
@@ -655,6 +670,7 @@ class RagModel(RagPreTrainedModel):
generator_enc_attentions=gen_outputs.encoder_attentions, generator_enc_attentions=gen_outputs.encoder_attentions,
generator_dec_hidden_states=gen_outputs.decoder_hidden_states, generator_dec_hidden_states=gen_outputs.decoder_hidden_states,
generator_dec_attentions=gen_outputs.decoder_attentions, generator_dec_attentions=gen_outputs.decoder_attentions,
generator_cross_attentions=gen_outputs.cross_attentions,
) )
@@ -803,6 +819,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
generator_enc_attentions=outputs.generator_enc_attentions, generator_enc_attentions=outputs.generator_enc_attentions,
generator_dec_hidden_states=outputs.generator_dec_hidden_states, generator_dec_hidden_states=outputs.generator_dec_hidden_states,
generator_dec_attentions=outputs.generator_dec_attentions, generator_dec_attentions=outputs.generator_dec_attentions,
generator_cross_attentions=outputs.generator_cross_attentions,
) )
@property @property
@@ -1264,6 +1281,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
generator_enc_attentions=outputs.generator_enc_attentions, generator_enc_attentions=outputs.generator_enc_attentions,
generator_dec_hidden_states=outputs.generator_dec_hidden_states, generator_dec_hidden_states=outputs.generator_dec_hidden_states,
generator_dec_attentions=outputs.generator_dec_attentions, generator_dec_attentions=outputs.generator_dec_attentions,
generator_cross_attentions=outputs.generator_cross_attentions,
) )
@torch.no_grad() @torch.no_grad()