Allow RAG to output decoder cross-attentions (#9789)
* get cross attns * add cross-attns doc strings * fix typo * line length * Apply suggestions from code review Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
This commit is contained in:
@@ -102,6 +102,12 @@ class RetrievAugLMMarginOutput(ModelOutput):
|
||||
|
||||
Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
|
||||
average in the self-attention heads.
|
||||
generator_cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Cross-attentions weights of the generator decoder, after the attention softmax, used to compute the
|
||||
weighted average in the cross-attention heads.
|
||||
"""
|
||||
|
||||
loss: Optional[torch.FloatTensor] = None
|
||||
@@ -120,6 +126,7 @@ class RetrievAugLMMarginOutput(ModelOutput):
|
||||
generator_enc_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
generator_dec_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
generator_dec_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
generator_cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -186,6 +193,12 @@ class RetrievAugLMOutput(ModelOutput):
|
||||
|
||||
Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
|
||||
average in the self-attention heads.
|
||||
generator_cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||
sequence_length, sequence_length)`.
|
||||
|
||||
Cross-attentions weights of the generator decoder, after the attention softmax, used to compute the
|
||||
weighted average in the cross-attention heads.
|
||||
"""
|
||||
|
||||
logits: torch.FloatTensor = None
|
||||
@@ -203,6 +216,7 @@ class RetrievAugLMOutput(ModelOutput):
|
||||
generator_enc_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
generator_dec_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||
generator_dec_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
generator_cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||
|
||||
|
||||
class RagPreTrainedModel(PreTrainedModel):
|
||||
@@ -619,6 +633,7 @@ class RagModel(RagPreTrainedModel):
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
@@ -655,6 +670,7 @@ class RagModel(RagPreTrainedModel):
|
||||
generator_enc_attentions=gen_outputs.encoder_attentions,
|
||||
generator_dec_hidden_states=gen_outputs.decoder_hidden_states,
|
||||
generator_dec_attentions=gen_outputs.decoder_attentions,
|
||||
generator_cross_attentions=gen_outputs.cross_attentions,
|
||||
)
|
||||
|
||||
|
||||
@@ -803,6 +819,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
|
||||
generator_enc_attentions=outputs.generator_enc_attentions,
|
||||
generator_dec_hidden_states=outputs.generator_dec_hidden_states,
|
||||
generator_dec_attentions=outputs.generator_dec_attentions,
|
||||
generator_cross_attentions=outputs.generator_cross_attentions,
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -1264,6 +1281,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
generator_enc_attentions=outputs.generator_enc_attentions,
|
||||
generator_dec_hidden_states=outputs.generator_dec_hidden_states,
|
||||
generator_dec_attentions=outputs.generator_dec_attentions,
|
||||
generator_cross_attentions=outputs.generator_cross_attentions,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
Reference in New Issue
Block a user