Fix cache for GPT-Neo-X (#17764)

* Fix cache for GPT-Neo-X

* Add more tests
This commit is contained in:
Sylvain Gugger
2022-06-20 08:43:36 -04:00
committed by GitHub
parent a2d34b7c04
commit fdb120805c
3 changed files with 10 additions and 2 deletions

View File

@@ -218,6 +218,14 @@ class GPTNeoXModelTest(ModelTesterMixin, unittest.TestCase):
self.model_tester.create_and_check_model_as_decoder(config, input_ids, input_mask)
def test_decoder_model_past_large_inputs(self):
config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past_large_inputs(config, input_ids, input_mask)
def test_model_for_causal_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: