[RAG, Bart] Align RAG, Bart cache with T5 and other models of transformers (#9098)

* fix rag

* fix slow test

* fix past in bart
This commit is contained in:
Patrick von Platen
2020-12-14 12:32:26 +01:00
committed by GitHub
parent 6587cf9f84
commit fa1ddced9e
3 changed files with 33 additions and 39 deletions

View File

@@ -1029,6 +1029,10 @@ class RagTokenForGeneration(RagPreTrainedModel):
         n_docs=None,
         **kwargs
     ):
+        if past is not None:
+            # if past is defined use only last decoder_input_ids
+            decoder_input_ids = decoder_input_ids[:, -1:]
         return {
             "input_ids": None,
             "encoder_outputs": encoder_outputs,
@@ -1057,23 +1061,17 @@ class RagTokenForGeneration(RagPreTrainedModel):
     def _reorder_cache(past, beam_idx):
         """Reorders cache for generation. BART-inspired but we need to take care of the extra dimension for docs"""
-        def _reorder_stacked(hidden_states):
-            n_docs = hidden_states.shape[0] // beam_idx.shape[0]
+        def _reorder_stacked(hidden_states, new_order):
+            n_docs = hidden_states.shape[0] // new_order.shape[0]
             hidden_states = hidden_states.view(-1, n_docs, *hidden_states.shape[1:])
-            hidden_states = hidden_states.index_select(0, beam_idx)
-            return hidden_states.view(-1, *hidden_states.shape[2:])
+            hidden_states = hidden_states.index_select(0, new_order)
+            result = hidden_states.view(-1, *hidden_states.shape[2:])
+            return result
-        def _reorder_buffer(attn_cache):
-            for k, input_buffer_k in attn_cache.items():
-                if input_buffer_k is not None:
-                    attn_cache[k] = _reorder_stacked(input_buffer_k)
-            return attn_cache
-        reordered_past = []
+        reordered_past = ()
         for layer_past in past:
             # get the correct batch idx from decoder layer's batch dim for cross and self-attn
-            layer_past_new = {attn_key: _reorder_buffer(attn_cache) for attn_key, attn_cache in layer_past.items()}
-            reordered_past.append(layer_past_new)
+            reordered_past += (tuple(_reorder_stacked(past_state, beam_idx) for past_state in layer_past),)
         return reordered_past