Add caching mechanism to BERT, RoBERTa (#9183)

* add past_key_values

* add use_cache option

* make mask before cutting ids

* adjust position_ids according to past_key_values

* flatten past_key_values

* fix positional embeds

* fix _reorder_cache

* set use_cache to false when not decoder, fix attention mask init

* add test for caching

* add past_key_values for Roberta

* fix position embeds

* add caching test for roberta

* add doc

* make style

* doc, fix attention mask, test

* small fixes

* adress patrick's comments

* input_ids shouldn't start with pad token

* use_cache only when decoder

* make consistent with bert

* make copies consistent

* add use_cache to encoder

* add past_key_values to tapas attention

* apply suggestions from code review

* make coppies consistent

* add attn mask in tests

* remove copied from longformer

* apply suggestions from code review

* fix bart test

* nit

* simplify model outputs

* fix doc

* fix output ordering
This commit is contained in:
Suraj Patil
2020-12-23 23:01:32 +05:30
committed by GitHub
parent a1cb6e9866
commit 88ef8893cd
17 changed files with 809 additions and 166 deletions

View File

@@ -25,6 +25,8 @@ from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, r
if is_torch_available():
import torch
from transformers import BertGenerationConfig, BertGenerationDecoder, BertGenerationEncoder
@@ -156,6 +158,64 @@ class BertGenerationEncoderTester:
)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
def create_and_check_decoder_model_past_large_inputs(
self,
config,
input_ids,
input_mask,
token_labels,
encoder_hidden_states,
encoder_attention_mask,
**kwargs,
):
config.is_decoder = True
config.add_cross_attention = True
model = BertGenerationDecoder(config=config).to(torch_device).eval()
# first forward pass
outputs = model(
input_ids,
attention_mask=input_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=True,
)
past_key_values = outputs.past_key_values
# create hypothetical multiple next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
next_mask = ids_tensor((self.batch_size, 3), vocab_size=2)
# append to next input_ids and
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
output_from_no_past = model(
next_input_ids,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_hidden_states=True,
)["hidden_states"][0]
output_from_past = model(
next_tokens,
attention_mask=next_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
output_hidden_states=True,
)["hidden_states"][0]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def create_and_check_for_causal_lm(
self,
config,
@@ -203,6 +263,10 @@ class BertGenerationEncoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
def test_model_as_decoder_with_default_input_mask(self):
# This regression test was failing with PyTorch < 1.3
(