[Reformer] - Cache hidden states and buckets to speed up inference (#5578)

* fix merge rebase

* add intermediate reformer code

* save intermediate caching results

* save intermediate

* save intermediate results

* save intermediate

* upload next step

* fix generate tests

* make tests work

* add named tuple output

* Apply suggestions from code review

* fix use_cache for False case

* fix tensor to gpu

* fix tensor to gpu

* refactor

* refactor and make style
This commit is contained in:
Patrick von Platen
2020-07-17 16:17:42 +02:00
committed by GitHub
parent 0b6c255a95
commit 9d37c56bab
3 changed files with 685 additions and 100 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -600,7 +600,7 @@ class XLNetModelOutput(ModelOutput):
@dataclass @dataclass
class XLNetLMHeadModelOutput(ModelOutput): class XLNetLMHeadModelOutput(ModelOutput):
""" """
Output type of :class:`~transformers.XLNetModel`. Output type of :class:`~transformers.XLNetLMHeadModel`.
Args: Args:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when ``labels`` is provided) loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when ``labels`` is provided)
@@ -637,7 +637,7 @@ class XLNetLMHeadModelOutput(ModelOutput):
@dataclass @dataclass
class XLNetForSequenceClassificationOutput(ModelOutput): class XLNetForSequenceClassificationOutput(ModelOutput):
""" """
Base class for outputs of sentence classification models. Output type of :class:`~transformers.XLNetForSequenceClassification`.
Args: Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided): loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided):
@@ -671,7 +671,7 @@ class XLNetForSequenceClassificationOutput(ModelOutput):
@dataclass @dataclass
class XLNetForTokenClassificationOutput(ModelOutput): class XLNetForTokenClassificationOutput(ModelOutput):
""" """
Base class for outputs of token classification models. Output type of :class:`~transformers.XLNetForTokenClassification`.
Args: Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) : loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :

View File

@@ -181,8 +181,8 @@ class ReformerModelTester:
model = ReformerModel(config=config) model = ReformerModel(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
(sequence_output,) = model(input_ids, attention_mask=input_mask) sequence_output, _ = model(input_ids, attention_mask=input_mask)
(sequence_output,) = model(input_ids) sequence_output, _ = model(input_ids)
result = { result = {
"sequence_output": sequence_output, "sequence_output": sequence_output,
@@ -193,17 +193,21 @@ class ReformerModelTester:
) )
def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels): def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels):
model = ReformerModelWithLMHead(config=config) config.is_decoder = False
config.lsh_num_chunks_after = 1
model = ReformerForMaskedLM(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0] loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0]
loss.backward() loss.backward()
def create_and_check_reformer_with_lm(self, config, input_ids, input_mask, choice_labels): def create_and_check_reformer_with_lm(self, config, input_ids, input_mask, choice_labels):
config.lsh_num_chunks_after = 0
config.is_decoder = True
model = ReformerModelWithLMHead(config=config) model = ReformerModelWithLMHead(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids) loss, prediction_scores, _ = model(input_ids, attention_mask=input_mask, labels=input_ids)
result = { result = {
"loss": loss, "loss": loss,
"prediction_scores": prediction_scores, "prediction_scores": prediction_scores,
@@ -332,9 +336,11 @@ class ReformerModelTester:
config.hidden_dropout_prob = 0 config.hidden_dropout_prob = 0
config.local_attention_probs_dropout_prob = 0 config.local_attention_probs_dropout_prob = 0
config.lsh_attention_probs_dropout_prob = 0 config.lsh_attention_probs_dropout_prob = 0
config.lsh_num_chunks_after = 1
config.is_decoder = False
torch.manual_seed(0) torch.manual_seed(0)
model = ReformerModelWithLMHead(config=config) model = ReformerForMaskedLM(config=config)
model.to(torch_device) model.to(torch_device)
model.train() model.train()
model.zero_grad() model.zero_grad()
@@ -348,7 +354,7 @@ class ReformerModelTester:
config.chunk_size_feed_forward = 1 config.chunk_size_feed_forward = 1
torch.manual_seed(0) torch.manual_seed(0)
model = ReformerModelWithLMHead(config=config) model = ReformerForMaskedLM(config=config)
model.to(torch_device) model.to(torch_device)
model.train() model.train()
model.zero_grad() model.zero_grad()
@@ -405,7 +411,22 @@ class ReformerModelTester:
output = model(input_ids, attention_mask=input_mask)[0] output = model(input_ids, attention_mask=input_mask)[0]
self.parent.assertFalse(torch.isnan(output).any().item()) self.parent.assertFalse(torch.isnan(output).any().item())
def create_and_check_reformer_model_generate(self, config, input_ids, input_mask, choice_labels):
    """Smoke-test ``generate()`` on a ``ReformerModelWithLMHead`` built from ``config``.

    ``config`` is mutated in place before the model is created. ``generate()``
    is called without any arguments, and the only assertion is that it
    returns something other than ``None``. ``input_ids``, ``input_mask`` and
    ``choice_labels`` are accepted for signature parity but are not used.
    """
    # Applied in this order, mirroring plain attribute assignment on the config.
    # NOTE(review): decoder mode with no "chunks after" is presumably what
    # causal generation requires -- confirm against the model implementation.
    overrides = (
        ("is_decoder", True),
        ("lsh_num_chunks_after", 0),
        ("bos_token_id", 0),
        ("eos_token_id", None),
        ("max_length", 20),
    )
    for attr_name, attr_value in overrides:
        setattr(config, attr_name, attr_value)

    lm_model = ReformerModelWithLMHead(config=config)
    lm_model.to(torch_device)
    lm_model.eval()

    generated = lm_model.generate()
    self.parent.assertIsNotNone(generated)
def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask, choice_labels): def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask, choice_labels):
config.is_decoder = True
config.lsh_num_chunks_after = 0
model = ReformerModelWithLMHead(config=config) model = ReformerModelWithLMHead(config=config)
model.to(torch_device) model.to(torch_device)
model.half() model.half()
@@ -418,13 +439,15 @@ class ReformerModelTester:
# force chunk length to be bigger than input_ids # force chunk length to be bigger than input_ids
config.lsh_attn_chunk_length = 2 * input_ids.shape[-1] config.lsh_attn_chunk_length = 2 * input_ids.shape[-1]
config.local_attn_chunk_length = 2 * input_ids.shape[-1] config.local_attn_chunk_length = 2 * input_ids.shape[-1]
model = ReformerModelWithLMHead(config=config) config.lsh_num_chunks_after = 1
config.is_decoder = False
model = ReformerForMaskedLM(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
output_logits = model(input_ids, attention_mask=input_mask)[0] output_logits = model(input_ids, attention_mask=input_mask)[0]
self.parent.assertTrue(output_logits.shape[1] == input_ids.shape[-1]) self.parent.assertTrue(output_logits.shape[1] == input_ids.shape[-1])
def create_and_check_longformer_for_question_answering(self, config, input_ids, input_mask, choice_labels): def create_and_check_reformer_for_question_answering(self, config, input_ids, input_mask, choice_labels):
model = ReformerForQuestionAnswering(config=config) model = ReformerForQuestionAnswering(config=config)
model.to(torch_device) model.to(torch_device)
model.eval() model.eval()
@@ -440,6 +463,33 @@ class ReformerModelTester:
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length]) self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result) self.check_loss_output(result)
def create_and_check_past_buckets_states(self, config, input_ids, input_mask, choice_labels):
    """Compare the last-position output computed with and without the cache.

    The input is split into everything but the last token and the last token
    alone. The prefix is run with ``use_cache=True`` to obtain
    ``past_buckets_states``; the last token is then run on top of that cache.
    The result is compared, at one randomly chosen index of the final
    dimension, with the last position of an uncached pass over the full
    input, using ``atol=1e-2``.

    ``config`` is mutated in place; ``input_mask`` and ``choice_labels`` are
    not used.
    """
    config.is_decoder = True
    config.lsh_num_chunks_before = 1
    config.lsh_num_chunks_after = 0

    lm_model = ReformerModelWithLMHead(config=config)
    lm_model.to(torch_device)
    lm_model.eval()

    prefix_ids, last_token_ids = input_ids[:, :-1], input_ids[:, -1:]

    # First pass: only the returned cache is of interest.
    _, cached_states = lm_model(prefix_ids, use_cache=True)

    # Last position, once from the cache and once from a full uncached pass.
    cached_out, _ = lm_model(last_token_ids, past_buckets_states=cached_states, use_cache=True)
    full_out_last = lm_model(input_ids)[0][:, -1]

    # Probe a single random index of the final dimension.
    probe_idx = torch.randint(full_out_last.shape[-1], (1, 1), device=torch_device).item()

    # The two paths only need to agree within the tolerance, not exactly.
    is_close = torch.allclose(cached_out[:, 0, probe_idx], full_out_last[:, probe_idx], atol=1e-2)
    self.parent.assertTrue(is_close)
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, input_mask, choice_labels) = config_and_inputs (config, input_ids, input_mask, choice_labels) = config_and_inputs
@@ -509,6 +559,18 @@ class ReformerTesterMixin:
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_no_chunking(*config_and_inputs) self.model_tester.create_and_check_reformer_no_chunking(*config_and_inputs)
def test_reformer_qa_answering(self):
    """Run the model tester's question-answering check on freshly prepared inputs."""
    prepared = self.model_tester.prepare_config_and_inputs()
    check = self.model_tester.create_and_check_reformer_for_question_answering
    check(*prepared)
def test_reformer_cached_inference(self):
    """Run the model tester's ``past_buckets_states`` check on freshly prepared inputs."""
    prepared = self.model_tester.prepare_config_and_inputs()
    check = self.model_tester.create_and_check_past_buckets_states
    check(*prepared)
def test_reformer_cached_generate(self):
    """Run the model tester's ``generate()`` check on freshly prepared inputs."""
    prepared = self.model_tester.prepare_config_and_inputs()
    check = self.model_tester.create_and_check_reformer_model_generate
    check(*prepared)
@slow @slow
def test_dropout_random_seed_is_changing(self): def test_dropout_random_seed_is_changing(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
@@ -621,8 +683,8 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
"num_buckets": 2, "num_buckets": 2,
"num_hashes": 4, "num_hashes": 4,
"lsh_attn_chunk_length": 4, "lsh_attn_chunk_length": 4,
"lsh_num_chunks_before": 2, "lsh_num_chunks_before": 1,
"lsh_num_chunks_after": 3, "lsh_num_chunks_after": 0,
"chunk_size_lm_head": 5, "chunk_size_lm_head": 5,
"chunk_size_feed_forward": 6, "chunk_size_feed_forward": 6,
"feed_forward_size": 32, "feed_forward_size": 32,
@@ -636,7 +698,9 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
"axial_pos_embds": True, "axial_pos_embds": True,
"axial_pos_shape": [4, 8], "axial_pos_shape": [4, 8],
"axial_pos_embds_dim": [16, 48], "axial_pos_embds_dim": [16, 48],
"attn_layers": ["lsh", "lsh", "lsh", "lsh"], # sanotheu
# "attn_layers": ["lsh", "lsh", "lsh", "lsh"],
"attn_layers": ["lsh"],
"pad_token_id": 0, "pad_token_id": 0,
"eos_token_id": 2, "eos_token_id": 2,
"scope": None, "scope": None,
@@ -1049,8 +1113,23 @@ class ReformerIntegrationTests(unittest.TestCase):
output_ids = model.generate( output_ids = model.generate(
input_ids, max_length=50, num_beams=4, early_stopping=True, do_sample=False, num_hashes=8 input_ids, max_length=50, num_beams=4, early_stopping=True, do_sample=False, num_hashes=8
) )
output_text = tokenizer.decode(output_ids[0]) output = tokenizer.decode(output_ids[0])
self.assertEqual( self.assertEqual(
output_text, output,
"A few months later state expression in his ideas, at the first entrance. He was positively for an inst", "A few months later state expression in his ideas, at the first entrance. He was positively for an inst",
) )
@slow
def test_pretrained_generate_use_cache_equality(self):
    """Check that generation gives the same text with and without ``use_cache``.

    Loads the ``google/reformer-crime-and-punishment`` checkpoint and its
    tokenizer, generates up to 130 tokens from the prompt
    ``"A few months later"`` once with ``use_cache=False`` and once with
    ``use_cache=True`` (both with ``num_hashes=8``), and asserts that the
    decoded strings are equal.
    """
    model = ReformerModelWithLMHead.from_pretrained("google/reformer-crime-and-punishment").to(torch_device)
    tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
    model.eval()
    input_ids = tokenizer.encode("A few months later", return_tensors="pt").to(torch_device)

    # The uncached run stays first, as in the original call order; only the
    # variable names were fixed so that each one matches its ``use_cache`` flag.
    output_ids_without_cache = model.generate(input_ids, max_length=130, num_hashes=8, use_cache=False)
    output_ids_with_cache = model.generate(input_ids, max_length=130, num_hashes=8, use_cache=True)

    output_without_cache = tokenizer.decode(output_ids_without_cache[0])
    output_with_cache = tokenizer.decode(output_ids_with_cache[0])
    self.assertEqual(output_with_cache, output_without_cache)