Add caching mechanism to BERT, RoBERTa (#9183)
* add past_key_values * add use_cache option * make mask before cutting ids * adjust position_ids according to past_key_values * flatten past_key_values * fix positional embeds * fix _reorder_cache * set use_cache to false when not decoder, fix attention mask init * add test for caching * add past_key_values for Roberta * fix position embeds * add caching test for roberta * add doc * make style * doc, fix attention mask, test * small fixes * adress patrick's comments * input_ids shouldn't start with pad token * use_cache only when decoder * make consistent with bert * make copies consistent * add use_cache to encoder * add past_key_values to tapas attention * apply suggestions from code review * make coppies consistent * add attn mask in tests * remove copied from longformer * apply suggestions from code review * fix bart test * nit * simplify model outputs * fix doc * fix output ordering
This commit is contained in:
@@ -345,11 +345,11 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
decoder_input_ids=None,
|
||||
decoder_attention_mask=None,
|
||||
encoder_outputs=None,
|
||||
past_key_values=None, # TODO: (PVP) implement :obj:`use_cache`
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
decoder_inputs_embeds=None,
|
||||
labels=None,
|
||||
use_cache=None, # TODO: (PVP) implement :obj:`use_cache`
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
@@ -413,18 +413,19 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
labels=labels,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
use_cache=use_cache,
|
||||
past_key_values=past_key_values,
|
||||
return_dict=return_dict,
|
||||
**kwargs_decoder,
|
||||
)
|
||||
|
||||
# TODO(PVP): currently it is not possible to use `past`
|
||||
if not return_dict:
|
||||
return decoder_outputs + encoder_outputs
|
||||
|
||||
return Seq2SeqLMOutput(
|
||||
loss=decoder_outputs.loss,
|
||||
logits=decoder_outputs.logits,
|
||||
past_key_values=None, # TODO(PVP) - need to implement cache for BERT, etc... before this works
|
||||
past_key_values=decoder_outputs.past_key_values,
|
||||
decoder_hidden_states=decoder_outputs.hidden_states,
|
||||
decoder_attentions=decoder_outputs.attentions,
|
||||
cross_attentions=decoder_outputs.cross_attentions,
|
||||
@@ -433,24 +434,19 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
encoder_attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, encoder_outputs=None, **kwargs):
|
||||
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids)
|
||||
def prepare_inputs_for_generation(
|
||||
self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
|
||||
):
|
||||
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past)
|
||||
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
|
||||
input_dict = {
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"decoder_input_ids": decoder_inputs["input_ids"],
|
||||
"encoder_outputs": encoder_outputs,
|
||||
"past_key_values": past,
|
||||
"use_cache": use_cache,
|
||||
}
|
||||
|
||||
# Ideally all models should have a :obj:`use_cache`
|
||||
# leave following to ifs until all have it implemented
|
||||
if "use_cache" in decoder_inputs:
|
||||
input_dict["decoder_use_cache"] = decoder_inputs["use_cache"]
|
||||
|
||||
if "past_key_values" in decoder_inputs:
|
||||
input_dict["past_key_values"] = decoder_inputs["past_key_values"]
|
||||
|
||||
return input_dict
|
||||
|
||||
def _reorder_cache(self, past, beam_idx):
|
||||
|
||||
Reference in New Issue
Block a user