[Reformer] - Cache hidden states and buckets to speed up inference (#5578)
* fix merge rebase * add intermediate reformer code * save intermediate caching results * save intermediate * save intermediate results * save intermediate * upload next step * fix generate tests * make tests work * add named tuple output * Apply suggestions from code review * fix use_cache for False case * fix tensor to gpu * fix tensor to gpu * refactor * refactor and make style
This commit is contained in:
committed by
GitHub
parent
0b6c255a95
commit
9d37c56bab
@@ -181,8 +181,8 @@ class ReformerModelTester:
|
||||
model = ReformerModel(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
(sequence_output,) = model(input_ids, attention_mask=input_mask)
|
||||
(sequence_output,) = model(input_ids)
|
||||
sequence_output, _ = model(input_ids, attention_mask=input_mask)
|
||||
sequence_output, _ = model(input_ids)
|
||||
|
||||
result = {
|
||||
"sequence_output": sequence_output,
|
||||
@@ -193,17 +193,21 @@ class ReformerModelTester:
|
||||
)
|
||||
|
||||
def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels):
|
||||
model = ReformerModelWithLMHead(config=config)
|
||||
config.is_decoder = False
|
||||
config.lsh_num_chunks_after = 1
|
||||
model = ReformerForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0]
|
||||
loss.backward()
|
||||
|
||||
def create_and_check_reformer_with_lm(self, config, input_ids, input_mask, choice_labels):
|
||||
config.lsh_num_chunks_after = 0
|
||||
config.is_decoder = True
|
||||
model = ReformerModelWithLMHead(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids)
|
||||
loss, prediction_scores, _ = model(input_ids, attention_mask=input_mask, labels=input_ids)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"prediction_scores": prediction_scores,
|
||||
@@ -332,9 +336,11 @@ class ReformerModelTester:
|
||||
config.hidden_dropout_prob = 0
|
||||
config.local_attention_probs_dropout_prob = 0
|
||||
config.lsh_attention_probs_dropout_prob = 0
|
||||
config.lsh_num_chunks_after = 1
|
||||
config.is_decoder = False
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = ReformerModelWithLMHead(config=config)
|
||||
model = ReformerForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
model.zero_grad()
|
||||
@@ -348,7 +354,7 @@ class ReformerModelTester:
|
||||
config.chunk_size_feed_forward = 1
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = ReformerModelWithLMHead(config=config)
|
||||
model = ReformerForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
model.zero_grad()
|
||||
@@ -405,7 +411,22 @@ class ReformerModelTester:
|
||||
output = model(input_ids, attention_mask=input_mask)[0]
|
||||
self.parent.assertFalse(torch.isnan(output).any().item())
|
||||
|
||||
def create_and_check_reformer_model_generate(self, config, input_ids, input_mask, choice_labels):
|
||||
config.is_decoder = True
|
||||
config.lsh_num_chunks_after = 0
|
||||
config.bos_token_id = 0
|
||||
config.eos_token_id = None
|
||||
config.max_length = 20
|
||||
|
||||
model = ReformerModelWithLMHead(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
output = model.generate()
|
||||
self.parent.assertIsNotNone(output)
|
||||
|
||||
def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask, choice_labels):
|
||||
config.is_decoder = True
|
||||
config.lsh_num_chunks_after = 0
|
||||
model = ReformerModelWithLMHead(config=config)
|
||||
model.to(torch_device)
|
||||
model.half()
|
||||
@@ -418,13 +439,15 @@ class ReformerModelTester:
|
||||
# force chunk length to be bigger than input_ids
|
||||
config.lsh_attn_chunk_length = 2 * input_ids.shape[-1]
|
||||
config.local_attn_chunk_length = 2 * input_ids.shape[-1]
|
||||
model = ReformerModelWithLMHead(config=config)
|
||||
config.lsh_num_chunks_after = 1
|
||||
config.is_decoder = False
|
||||
model = ReformerForMaskedLM(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
output_logits = model(input_ids, attention_mask=input_mask)[0]
|
||||
self.parent.assertTrue(output_logits.shape[1] == input_ids.shape[-1])
|
||||
|
||||
def create_and_check_longformer_for_question_answering(self, config, input_ids, input_mask, choice_labels):
|
||||
def create_and_check_reformer_for_question_answering(self, config, input_ids, input_mask, choice_labels):
|
||||
model = ReformerForQuestionAnswering(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -440,6 +463,33 @@ class ReformerModelTester:
|
||||
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_past_buckets_states(self, config, input_ids, input_mask, choice_labels):
|
||||
config.is_decoder = True
|
||||
config.lsh_num_chunks_before = 1
|
||||
config.lsh_num_chunks_after = 0
|
||||
model = ReformerModelWithLMHead(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
input_ids_first = input_ids[:, :-1]
|
||||
input_ids_second = input_ids[:, -1:]
|
||||
|
||||
# return saved cache
|
||||
_, past_buckets_states = model(input_ids_first, use_cache=True)
|
||||
|
||||
# calculate last output with and without cache
|
||||
outputs_with_cache, _ = model(input_ids_second, past_buckets_states=past_buckets_states, use_cache=True)
|
||||
outputs_without_cache = model(input_ids)[0][:, -1]
|
||||
|
||||
# select random slice idx
|
||||
random_slice_idx = torch.randint(outputs_without_cache.shape[-1], (1, 1), device=torch_device).item()
|
||||
|
||||
# outputs should be similar within range
|
||||
self.parent.assertTrue(
|
||||
torch.allclose(
|
||||
outputs_with_cache[:, 0, random_slice_idx], outputs_without_cache[:, random_slice_idx], atol=1e-2
|
||||
)
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(config, input_ids, input_mask, choice_labels) = config_and_inputs
|
||||
@@ -509,6 +559,18 @@ class ReformerTesterMixin:
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_reformer_no_chunking(*config_and_inputs)
|
||||
|
||||
def test_reformer_qa_answering(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_reformer_for_question_answering(*config_and_inputs)
|
||||
|
||||
def test_reformer_cached_inference(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_past_buckets_states(*config_and_inputs)
|
||||
|
||||
def test_reformer_cached_generate(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_reformer_model_generate(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_dropout_random_seed_is_changing(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
@@ -621,8 +683,8 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
|
||||
"num_buckets": 2,
|
||||
"num_hashes": 4,
|
||||
"lsh_attn_chunk_length": 4,
|
||||
"lsh_num_chunks_before": 2,
|
||||
"lsh_num_chunks_after": 3,
|
||||
"lsh_num_chunks_before": 1,
|
||||
"lsh_num_chunks_after": 0,
|
||||
"chunk_size_lm_head": 5,
|
||||
"chunk_size_feed_forward": 6,
|
||||
"feed_forward_size": 32,
|
||||
@@ -636,7 +698,9 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T
|
||||
"axial_pos_embds": True,
|
||||
"axial_pos_shape": [4, 8],
|
||||
"axial_pos_embds_dim": [16, 48],
|
||||
"attn_layers": ["lsh", "lsh", "lsh", "lsh"],
|
||||
# sanotheu
|
||||
# "attn_layers": ["lsh", "lsh", "lsh", "lsh"],
|
||||
"attn_layers": ["lsh"],
|
||||
"pad_token_id": 0,
|
||||
"eos_token_id": 2,
|
||||
"scope": None,
|
||||
@@ -1049,8 +1113,23 @@ class ReformerIntegrationTests(unittest.TestCase):
|
||||
output_ids = model.generate(
|
||||
input_ids, max_length=50, num_beams=4, early_stopping=True, do_sample=False, num_hashes=8
|
||||
)
|
||||
output_text = tokenizer.decode(output_ids[0])
|
||||
output = tokenizer.decode(output_ids[0])
|
||||
|
||||
self.assertEqual(
|
||||
output_text,
|
||||
output,
|
||||
"A few months later state expression in his ideas, at the first entrance. He was positively for an inst",
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_pretrained_generate_use_cache_equality(self):
|
||||
model = ReformerModelWithLMHead.from_pretrained("google/reformer-crime-and-punishment").to(torch_device)
|
||||
tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment")
|
||||
model.eval()
|
||||
input_ids = tokenizer.encode("A few months later", return_tensors="pt").to(torch_device)
|
||||
output_ids_with_cache = model.generate(input_ids, max_length=130, num_hashes=8, use_cache=False)
|
||||
output_ids_without_cache = model.generate(input_ids, max_length=130, num_hashes=8, use_cache=True)
|
||||
|
||||
output_with_cache = tokenizer.decode(output_ids_with_cache[0])
|
||||
output_without_cache = tokenizer.decode(output_ids_without_cache[0])
|
||||
|
||||
self.assertEqual(output_with_cache, output_without_cache)
|
||||
|
||||
Reference in New Issue
Block a user