Allow RAG to output decoder cross-attentions (#9789)
* get cross attns * add cross-attns doc strings * fix typo * line length * Apply suggestions from code review Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com> Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
This commit is contained in:
@@ -102,6 +102,12 @@ class RetrievAugLMMarginOutput(ModelOutput):
|
|||||||
|
|
||||||
Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
|
Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
|
||||||
average in the self-attention heads.
|
average in the self-attention heads.
|
||||||
|
generator_cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Cross-attentions weights of the generator decoder, after the attention softmax, used to compute the
|
||||||
|
weighted average in the cross-attention heads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
loss: Optional[torch.FloatTensor] = None
|
loss: Optional[torch.FloatTensor] = None
|
||||||
@@ -120,6 +126,7 @@ class RetrievAugLMMarginOutput(ModelOutput):
|
|||||||
generator_enc_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
generator_enc_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
generator_dec_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
generator_dec_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
generator_dec_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
generator_dec_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
generator_cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -186,6 +193,12 @@ class RetrievAugLMOutput(ModelOutput):
|
|||||||
|
|
||||||
Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
|
Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
|
||||||
average in the self-attention heads.
|
average in the self-attention heads.
|
||||||
|
generator_cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
||||||
|
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
||||||
|
sequence_length, sequence_length)`.
|
||||||
|
|
||||||
|
Cross-attentions weights of the generator decoder, after the attention softmax, used to compute the
|
||||||
|
weighted average in the cross-attention heads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logits: torch.FloatTensor = None
|
logits: torch.FloatTensor = None
|
||||||
@@ -203,6 +216,7 @@ class RetrievAugLMOutput(ModelOutput):
|
|||||||
generator_enc_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
generator_enc_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
generator_dec_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
generator_dec_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
generator_dec_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
generator_dec_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
generator_cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
|
||||||
|
|
||||||
|
|
||||||
class RagPreTrainedModel(PreTrainedModel):
|
class RagPreTrainedModel(PreTrainedModel):
|
||||||
@@ -619,6 +633,7 @@ class RagModel(RagPreTrainedModel):
|
|||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -655,6 +670,7 @@ class RagModel(RagPreTrainedModel):
|
|||||||
generator_enc_attentions=gen_outputs.encoder_attentions,
|
generator_enc_attentions=gen_outputs.encoder_attentions,
|
||||||
generator_dec_hidden_states=gen_outputs.decoder_hidden_states,
|
generator_dec_hidden_states=gen_outputs.decoder_hidden_states,
|
||||||
generator_dec_attentions=gen_outputs.decoder_attentions,
|
generator_dec_attentions=gen_outputs.decoder_attentions,
|
||||||
|
generator_cross_attentions=gen_outputs.cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -803,6 +819,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
|
|||||||
generator_enc_attentions=outputs.generator_enc_attentions,
|
generator_enc_attentions=outputs.generator_enc_attentions,
|
||||||
generator_dec_hidden_states=outputs.generator_dec_hidden_states,
|
generator_dec_hidden_states=outputs.generator_dec_hidden_states,
|
||||||
generator_dec_attentions=outputs.generator_dec_attentions,
|
generator_dec_attentions=outputs.generator_dec_attentions,
|
||||||
|
generator_cross_attentions=outputs.generator_cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -1264,6 +1281,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
generator_enc_attentions=outputs.generator_enc_attentions,
|
generator_enc_attentions=outputs.generator_enc_attentions,
|
||||||
generator_dec_hidden_states=outputs.generator_dec_hidden_states,
|
generator_dec_hidden_states=outputs.generator_dec_hidden_states,
|
||||||
generator_dec_attentions=outputs.generator_dec_attentions,
|
generator_dec_attentions=outputs.generator_dec_attentions,
|
||||||
|
generator_cross_attentions=outputs.generator_cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|||||||
Reference in New Issue
Block a user