[RAG, Bart] Align RAG, Bart cache with T5 and other models of transformers (#9098)
* fix rag * fix slow test * fix past in bart
This commit is contained in:
committed by
GitHub
parent
6587cf9f84
commit
fa1ddced9e
@@ -16,7 +16,7 @@
|
|||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Dict, Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -407,7 +407,7 @@ class BartDecoderLayer(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
encoder_hidden_states: torch.Tensor,
|
encoder_hidden_states: torch.Tensor,
|
||||||
encoder_attn_mask: Optional[torch.Tensor] = None,
|
encoder_attn_mask: Optional[torch.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
attn_mask: Optional[torch.Tensor] = None,
|
attn_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: Optional[torch.Tensor] = False,
|
output_attentions: Optional[torch.Tensor] = False,
|
||||||
):
|
):
|
||||||
@@ -416,9 +416,10 @@ class BartDecoderLayer(nn.Module):
|
|||||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||||
# Self Attention
|
# Self Attention
|
||||||
|
|
||||||
# decoder uni-directional self-attention cached key/values tuple is at first position
|
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
||||||
self_attn_past_key_value = past_key_value[0] if past_key_value is not None else None
|
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
||||||
hidden_states, self_attn_weights, self_attn_present_key_value = self.self_attn(
|
# add present self-attn cache to positions 1,2 of present_key_value tuple
|
||||||
|
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
past_key_value=self_attn_past_key_value,
|
past_key_value=self_attn_past_key_value,
|
||||||
attn_mask=attn_mask,
|
attn_mask=attn_mask,
|
||||||
@@ -437,8 +438,8 @@ class BartDecoderLayer(nn.Module):
|
|||||||
if self.normalize_before:
|
if self.normalize_before:
|
||||||
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
||||||
|
|
||||||
# cross_attn cached key/values tuple is at second position
|
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
||||||
cross_attn_past_key_value = past_key_value[1] if past_key_value is not None else None
|
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
||||||
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
|
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
key_value_states=encoder_hidden_states,
|
key_value_states=encoder_hidden_states,
|
||||||
@@ -451,6 +452,9 @@ class BartDecoderLayer(nn.Module):
|
|||||||
if not self.normalize_before:
|
if not self.normalize_before:
|
||||||
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
||||||
|
|
||||||
|
# add cross-attn to positions 3,4 of present_key_value tuple
|
||||||
|
present_key_value = present_key_value + cross_attn_present_key_value
|
||||||
|
|
||||||
# Fully Connected
|
# Fully Connected
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
if self.normalize_before:
|
if self.normalize_before:
|
||||||
@@ -463,9 +467,6 @@ class BartDecoderLayer(nn.Module):
|
|||||||
if not self.normalize_before:
|
if not self.normalize_before:
|
||||||
hidden_states = self.final_layer_norm(hidden_states)
|
hidden_states = self.final_layer_norm(hidden_states)
|
||||||
|
|
||||||
# make sure decoder uni-directional self-attn at 1st position and cross-attn at 2nd position.
|
|
||||||
present_key_value = (self_attn_present_key_value, cross_attn_present_key_value)
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
hidden_states,
|
hidden_states,
|
||||||
self_attn_weights,
|
self_attn_weights,
|
||||||
@@ -600,7 +601,7 @@ BART_INPUTS_DOCSTRING = r"""
|
|||||||
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
:obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
|
||||||
`optional`) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
|
`optional`) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
|
||||||
cross-attention of the decoder.
|
cross-attention of the decoder.
|
||||||
past_key_values (:obj:`Tuple[Tuple[Tuple[torch.Tensor]]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 4 tensors: the self-attention key and value states of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)` followed by the cross-attention key and value states):
|
||||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
|
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
|
||||||
|
|
||||||
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
||||||
@@ -857,7 +858,7 @@ class BartDecoder(BartPretrainedModel):
|
|||||||
- 0 for tokens that are **masked**.
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
past_key_values (:obj:`Tuple[Tuple[Tuple[torch.Tensor]]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 4 tensors: the self-attention key and value states of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)` followed by the cross-attention key and value states):
|
||||||
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
|
||||||
decoding.
|
decoding.
|
||||||
|
|
||||||
@@ -897,7 +898,7 @@ class BartDecoder(BartPretrainedModel):
|
|||||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||||
|
|
||||||
# past_key_values_length
|
# past_key_values_length
|
||||||
past_key_values_length = past_key_values[0][0][0].shape[2] if past_key_values is not None else 0
|
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
||||||
|
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||||
@@ -1284,12 +1285,9 @@ class BartForConditionalGeneration(BartPretrainedModel):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _reorder_cache(past, beam_idx):
|
def _reorder_cache(past, beam_idx):
|
||||||
def _reorder_buffer(cache: Tuple[torch.Tensor], new_order) -> Dict:
|
|
||||||
return tuple(past_state.index_select(0, new_order) for past_state in cache)
|
|
||||||
|
|
||||||
reordered_past = ()
|
reordered_past = ()
|
||||||
for layer_past in past:
|
for layer_past in past:
|
||||||
reordered_past += (tuple(_reorder_buffer(cache, beam_idx) for cache in layer_past),)
|
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
|
||||||
return reordered_past
|
return reordered_past
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1029,6 +1029,10 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
n_docs=None,
|
n_docs=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
|
if past is not None:
|
||||||
|
# if past is defined use only last decoder_input_ids
|
||||||
|
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"input_ids": None,
|
"input_ids": None,
|
||||||
"encoder_outputs": encoder_outputs,
|
"encoder_outputs": encoder_outputs,
|
||||||
@@ -1057,23 +1061,17 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
|||||||
def _reorder_cache(past, beam_idx):
|
def _reorder_cache(past, beam_idx):
|
||||||
"""Reorders cache for generation. BART-inspired but we need to take care of the extra dimension for docs"""
|
"""Reorders cache for generation. BART-inspired but we need to take care of the extra dimension for docs"""
|
||||||
|
|
||||||
def _reorder_stacked(hidden_states):
|
def _reorder_stacked(hidden_states, new_order):
|
||||||
n_docs = hidden_states.shape[0] // beam_idx.shape[0]
|
n_docs = hidden_states.shape[0] // new_order.shape[0]
|
||||||
hidden_states = hidden_states.view(-1, n_docs, *hidden_states.shape[1:])
|
hidden_states = hidden_states.view(-1, n_docs, *hidden_states.shape[1:])
|
||||||
hidden_states = hidden_states.index_select(0, beam_idx)
|
hidden_states = hidden_states.index_select(0, new_order)
|
||||||
return hidden_states.view(-1, *hidden_states.shape[2:])
|
result = hidden_states.view(-1, *hidden_states.shape[2:])
|
||||||
|
return result
|
||||||
|
|
||||||
def _reorder_buffer(attn_cache):
|
reordered_past = ()
|
||||||
for k, input_buffer_k in attn_cache.items():
|
|
||||||
if input_buffer_k is not None:
|
|
||||||
attn_cache[k] = _reorder_stacked(input_buffer_k)
|
|
||||||
return attn_cache
|
|
||||||
|
|
||||||
reordered_past = []
|
|
||||||
for layer_past in past:
|
for layer_past in past:
|
||||||
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
|
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
|
||||||
layer_past_new = {attn_key: _reorder_buffer(attn_cache) for attn_key, attn_cache in layer_past.items()}
|
reordered_past += (tuple(_reorder_stacked(past_state, beam_idx) for past_state in layer_past),)
|
||||||
reordered_past.append(layer_past_new)
|
|
||||||
|
|
||||||
return reordered_past
|
return reordered_past
|
||||||
|
|
||||||
|
|||||||
@@ -535,7 +535,6 @@ class RagDPRBartTest(RagTestMixin, unittest.TestCase):
|
|||||||
n_docs=self.n_docs,
|
n_docs=self.n_docs,
|
||||||
retrieval_vector_size=self.retrieval_vector_size,
|
retrieval_vector_size=self.retrieval_vector_size,
|
||||||
max_combined_length=self.max_combined_length,
|
max_combined_length=self.max_combined_length,
|
||||||
use_cache=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -565,7 +564,6 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase):
|
|||||||
n_docs=self.n_docs,
|
n_docs=self.n_docs,
|
||||||
retrieval_vector_size=self.retrieval_vector_size,
|
retrieval_vector_size=self.retrieval_vector_size,
|
||||||
max_combined_length=self.max_combined_length,
|
max_combined_length=self.max_combined_length,
|
||||||
use_cache=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -758,8 +756,8 @@ class RagModelIntegrationTests(unittest.TestCase):
|
|||||||
generator_tokenizer=rag_decoder_tokenizer,
|
generator_tokenizer=rag_decoder_tokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
rag_token = self.sequence_model
|
rag_sequence = self.sequence_model
|
||||||
rag_token.set_retriever(rag_retriever)
|
rag_sequence.set_retriever(rag_retriever)
|
||||||
|
|
||||||
input_ids = rag_question_encoder_tokenizer(
|
input_ids = rag_question_encoder_tokenizer(
|
||||||
"who sings does he love me with reba", return_tensors="pt"
|
"who sings does he love me with reba", return_tensors="pt"
|
||||||
@@ -767,9 +765,9 @@ class RagModelIntegrationTests(unittest.TestCase):
|
|||||||
|
|
||||||
input_ids = input_ids.to(torch_device)
|
input_ids = input_ids.to(torch_device)
|
||||||
|
|
||||||
output_ids = rag_token.generate(
|
output_ids = rag_sequence.generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
decoder_start_token_id=rag_token.generator.config.decoder_start_token_id,
|
decoder_start_token_id=rag_sequence.generator.config.decoder_start_token_id,
|
||||||
num_beams=2,
|
num_beams=2,
|
||||||
num_return_sequences=2,
|
num_return_sequences=2,
|
||||||
)
|
)
|
||||||
@@ -810,7 +808,7 @@ class RagModelIntegrationTests(unittest.TestCase):
|
|||||||
retriever = RagRetriever.from_pretrained(
|
retriever = RagRetriever.from_pretrained(
|
||||||
"facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
|
"facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
|
||||||
)
|
)
|
||||||
rag_sequence = RagTokenForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
|
rag_sequence = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
|
||||||
torch_device
|
torch_device
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -844,9 +842,9 @@ class RagModelIntegrationTests(unittest.TestCase):
|
|||||||
" walls of the abdomen",
|
" walls of the abdomen",
|
||||||
" spodumene",
|
" spodumene",
|
||||||
" obama",
|
" obama",
|
||||||
" grainger's compound",
|
" new orleans",
|
||||||
" japan",
|
" japan",
|
||||||
" old trafford stadium",
|
" old trafford",
|
||||||
]
|
]
|
||||||
self.assertListEqual(outputs, EXPECTED_OUTPUTS)
|
self.assertListEqual(outputs, EXPECTED_OUTPUTS)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user