Add caching mechanism to BERT, RoBERTa (#9183)

* add past_key_values

* add use_cache option

* make mask before cutting ids

* adjust position_ids according to past_key_values

* flatten past_key_values

* fix positional embeds

* fix _reorder_cache

* set use_cache to false when not decoder, fix attention mask init

* add test for caching

* add past_key_values for Roberta

* fix position embeds

* add caching test for roberta

* add doc

* make style

* doc, fix attention mask, test

* small fixes

* address Patrick's comments

* input_ids shouldn't start with pad token

* use_cache only when decoder

* make consistent with bert

* make copies consistent

* add use_cache to encoder

* add past_key_values to tapas attention

* apply suggestions from code review

* make copies consistent

* add attn mask in tests

* remove copied from longformer

* apply suggestions from code review

* fix bart test

* nit

* simplify model outputs

* fix doc

* fix output ordering
This commit is contained in:
Suraj Patil
2020-12-23 23:01:32 +05:30
committed by GitHub
parent a1cb6e9866
commit 88ef8893cd
17 changed files with 809 additions and 166 deletions

View File

@@ -345,11 +345,11 @@ class EncoderDecoderModel(PreTrainedModel):
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs=None,
past_key_values=None, # TODO: (PVP) implement :obj:`use_cache`
past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
labels=None,
use_cache=None, # TODO: (PVP) implement :obj:`use_cache`
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
@@ -413,18 +413,19 @@ class EncoderDecoderModel(PreTrainedModel):
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
past_key_values=past_key_values,
return_dict=return_dict,
**kwargs_decoder,
)
# TODO(PVP): currently it is not possible to use `past`
if not return_dict:
return decoder_outputs + encoder_outputs
return Seq2SeqLMOutput(
loss=decoder_outputs.loss,
logits=decoder_outputs.logits,
past_key_values=None, # TODO(PVP) - need to implement cache for BERT, etc... before this works
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
@@ -433,24 +434,19 @@ class EncoderDecoderModel(PreTrainedModel):
encoder_attentions=encoder_outputs.attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, encoder_outputs=None, **kwargs):
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids)
def prepare_inputs_for_generation(
self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
):
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
input_dict = {
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"decoder_input_ids": decoder_inputs["input_ids"],
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"use_cache": use_cache,
}
# Ideally all models should have a :obj:`use_cache`
# leave following to ifs until all have it implemented
if "use_cache" in decoder_inputs:
input_dict["decoder_use_cache"] = decoder_inputs["use_cache"]
if "past_key_values" in decoder_inputs:
input_dict["past_key_values"] = decoder_inputs["past_key_values"]
return input_dict
def _reorder_cache(self, past, beam_idx):