[RAG, Bart] Align RAG, Bart cache with T5 and other models of transformers (#9098)
* fix rag
* fix slow test
* fix past in bart
This commit is contained in:
committed by GitHub
parent 6587cf9f84
commit fa1ddced9e
@@ -1029,6 +1029,10 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
n_docs=None,
|
||||
**kwargs
|
||||
):
|
||||
if past is not None:
|
||||
# if past is defined use only last decoder_input_ids
|
||||
decoder_input_ids = decoder_input_ids[:, -1:]
|
||||
|
||||
return {
|
||||
"input_ids": None,
|
||||
"encoder_outputs": encoder_outputs,
|
||||
@@ -1057,23 +1061,17 @@ class RagTokenForGeneration(RagPreTrainedModel):
|
||||
def _reorder_cache(past, beam_idx):
|
||||
"""Reorders cache for generation. BART-inspired but we need to take care of the extra dimension for docs"""
|
||||
|
||||
def _reorder_stacked(hidden_states):
|
||||
n_docs = hidden_states.shape[0] // beam_idx.shape[0]
|
||||
def _reorder_stacked(hidden_states, new_order):
|
||||
n_docs = hidden_states.shape[0] // new_order.shape[0]
|
||||
hidden_states = hidden_states.view(-1, n_docs, *hidden_states.shape[1:])
|
||||
hidden_states = hidden_states.index_select(0, beam_idx)
|
||||
return hidden_states.view(-1, *hidden_states.shape[2:])
|
||||
hidden_states = hidden_states.index_select(0, new_order)
|
||||
result = hidden_states.view(-1, *hidden_states.shape[2:])
|
||||
return result
|
||||
|
||||
def _reorder_buffer(attn_cache):
|
||||
for k, input_buffer_k in attn_cache.items():
|
||||
if input_buffer_k is not None:
|
||||
attn_cache[k] = _reorder_stacked(input_buffer_k)
|
||||
return attn_cache
|
||||
|
||||
reordered_past = []
|
||||
reordered_past = ()
|
||||
for layer_past in past:
|
||||
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
|
||||
layer_past_new = {attn_key: _reorder_buffer(attn_cache) for attn_key, attn_cache in layer_past.items()}
|
||||
reordered_past.append(layer_past_new)
|
||||
reordered_past += (tuple(_reorder_stacked(past_state, beam_idx) for past_state in layer_past),)
|
||||
|
||||
return reordered_past
|
||||
|
||||
|
||||
Reference in New Issue
Block a user