[Config, Caching] Remove output_past everywhere and replace by use_cache argument (#3734)
* remove output_past from pt * make style * add optional input length for gpt2 * add use cache to prepare input * save memory in gpt2 * correct gpt2 test inputs * make past input optional for gpt2 * finish use_cache for all models * make style * delete modeling_gpt2 change in test file * correct docstring * correct is true statements for gpt2
This commit is contained in:
committed by
GitHub
parent
092cf881a5
commit
01c37dcdb5
@@ -128,7 +128,6 @@ class ModelTesterMixin:
|
||||
for model_class in self.all_model_classes:
|
||||
config.output_attentions = True
|
||||
config.output_hidden_states = False
|
||||
config.output_past = False
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -664,10 +663,6 @@ class ModelTesterMixin:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
# needed for Bart beam search
|
||||
config.output_past = True
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
|
||||
@@ -227,7 +227,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
model.eval()
|
||||
|
||||
# first forward pass
|
||||
output, past_key_value_states = model(input_ids)
|
||||
output, past_key_value_states = model(input_ids, use_cache=True)
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
@@ -235,8 +235,8 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
# append to next input_ids and
|
||||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
|
||||
output_from_no_past, _ = model(next_input_ids)
|
||||
output_from_past, _ = model(next_tokens, past_key_value_states=past_key_value_states)
|
||||
output_from_no_past = model(next_input_ids)[0]
|
||||
output_from_past = model(next_tokens, past_key_value_states=past_key_value_states)[0]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
@@ -260,7 +260,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
attn_mask[:, half_seq_length:] = 0
|
||||
|
||||
# first forward pass
|
||||
output, past_key_value_states = model(input_ids, attention_mask=attn_mask)
|
||||
output, past_key_value_states = model(input_ids, attention_mask=attn_mask, use_cache=True)
|
||||
|
||||
# create hypothetical next token and extent to next_input_ids
|
||||
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
|
||||
@@ -277,10 +277,10 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
)
|
||||
|
||||
# get two different outputs
|
||||
output_from_no_past, _ = model(next_input_ids, attention_mask=attn_mask)
|
||||
output_from_past, _ = model(
|
||||
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)[0]
|
||||
output_from_past = model(
|
||||
next_tokens, past_key_value_states=past_key_value_states, attention_mask=attn_mask
|
||||
)
|
||||
)[0]
|
||||
|
||||
# select random slice
|
||||
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
|
||||
@@ -298,10 +298,10 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
torch.manual_seed(0)
|
||||
model.set_output_past(False)
|
||||
output_without_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
|
||||
output_without_past_cache = model.generate(
|
||||
input_ids[:1], num_beams=2, max_length=5, do_sample=True, use_cache=False
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
model.set_output_past(True)
|
||||
output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
|
||||
self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
|
||||
|
||||
@@ -321,6 +321,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_input_ids": decoder_input_ids,
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"use_cache": False,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@@ -460,10 +460,6 @@ class TFModelTesterMixin:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
# needed for Bart beam search
|
||||
config.output_past = True
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user